Spaces: Runtime error
Commit: first upload
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- README.md +2 -1
- app.py +59 -0
- artst/__init__.py +1 -0
- artst/__pycache__/__init__.cpython-38.pyc +0 -0
- artst/__pycache__/sequence_generator.cpython-38.pyc +0 -0
- artst/criterions/__init__.py +10 -0
- artst/criterions/__pycache__/__init__.cpython-38.pyc +0 -0
- artst/criterions/__pycache__/artst_criterion.cpython-38.pyc +0 -0
- artst/criterions/__pycache__/speech_pretrain_criterion.cpython-38.pyc +0 -0
- artst/criterions/__pycache__/speech_to_text_loss.cpython-38.pyc +0 -0
- artst/criterions/__pycache__/speecht5_criterion.cpython-38.pyc +0 -0
- artst/criterions/__pycache__/text_pretrain_criterion.cpython-38.pyc +0 -0
- artst/criterions/__pycache__/text_to_speech_loss.cpython-38.pyc +0 -0
- artst/criterions/artst_criterion.py +443 -0
- artst/criterions/speech_pretrain_criterion.py +265 -0
- artst/criterions/speech_to_text_loss.py +473 -0
- artst/criterions/text_pretrain_criterion.py +142 -0
- artst/criterions/text_to_speech_loss.py +425 -0
- artst/data/__init__.py +0 -0
- artst/data/__pycache__/__init__.cpython-38.pyc +0 -0
- artst/data/__pycache__/multitask_dataset.cpython-38.pyc +0 -0
- artst/data/__pycache__/speech_dataset.cpython-38.pyc +0 -0
- artst/data/__pycache__/speech_to_class_dataset.cpython-38.pyc +0 -0
- artst/data/__pycache__/speech_to_speech_dataset.cpython-38.pyc +0 -0
- artst/data/__pycache__/speech_to_text_dataset.cpython-38.pyc +0 -0
- artst/data/__pycache__/text_dataset.cpython-38.pyc +0 -0
- artst/data/__pycache__/text_to_speech_dataset.cpython-38.pyc +0 -0
- artst/data/multitask_dataset.py +263 -0
- artst/data/speech_dataset.py +475 -0
- artst/data/speech_to_class_dataset.py +260 -0
- artst/data/speech_to_speech_dataset.py +280 -0
- artst/data/speech_to_text_dataset.py +298 -0
- artst/data/text_dataset.py +474 -0
- artst/data/text_to_speech_dataset.py +344 -0
- artst/models/__init__.py +2 -0
- artst/models/__pycache__/__init__.cpython-38.pyc +0 -0
- artst/models/__pycache__/artst.cpython-38.pyc +0 -0
- artst/models/__pycache__/speecht5.cpython-38.pyc +0 -0
- artst/models/__pycache__/t5_transformer_lm.cpython-38.pyc +0 -0
- artst/models/artst.py +1448 -0
- artst/models/modules/__init__.py +0 -0
- artst/models/modules/__pycache__/__init__.cpython-38.pyc +0 -0
- artst/models/modules/__pycache__/decoder.cpython-38.pyc +0 -0
- artst/models/modules/__pycache__/encoder.cpython-38.pyc +0 -0
- artst/models/modules/__pycache__/multihead_attention.cpython-38.pyc +0 -0
- artst/models/modules/__pycache__/speaker_decoder_postnet.cpython-38.pyc +0 -0
- artst/models/modules/__pycache__/speech_decoder_postnet.cpython-38.pyc +0 -0
- artst/models/modules/__pycache__/speech_decoder_prenet.cpython-38.pyc +0 -0
- artst/models/modules/__pycache__/speech_encoder_postnet.cpython-38.pyc +0 -0
- artst/models/modules/__pycache__/speech_encoder_prenet.cpython-38.pyc +0 -0
README.md
CHANGED
@@ -1,5 +1,5 @@
 ---
-title:
+title: ArtstTTS
 emoji: 🔥
 colorFrom: yellow
 colorTo: gray
@@ -7,6 +7,7 @@ sdk: gradio
 sdk_version: 4.7.1
 app_file: app.py
 pinned: false
+python_version: 3.8.2
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,59 @@
import os
import torch
import gradio as gr
import os.path as op
import pyarabic.araby as araby

from artst.tasks.artst import ArTSTTask
from transformers import SpeechT5HifiGan
from artst.models.artst import ArTSTTransformerModel
from fairseq.tasks.hubert_pretraining import LabelEncoder
from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

WORK_DIR = os.getcwd()
checkpoint = torch.load('ckpts/clartts_tts.pt')
checkpoint['cfg']['task'].t5_task = 't2s'
task = ArTSTTask.setup_task(checkpoint['cfg']['task'])

emb_path = 'embs/clartts.npy'
model = ArTSTTransformerModel.build_model(checkpoint['cfg']['model'], task)
model.load_state_dict(checkpoint['model'])

checkpoint['cfg']['task'].bpe_tokenizer = task.build_bpe(checkpoint['cfg']['model'])
tokenizer = checkpoint['cfg']['task'].bpe_tokenizer

processor = LabelEncoder(task.dicts['text'])

vocoder = SpeechT5HifiGan.from_pretrained('microsoft/speecht5_hifigan').to(device)

def get_embs(emb_path):
    spkembs = get_features_or_waveform(emb_path)
    spkembs = torch.from_numpy(spkembs).float().unsqueeze(0)
    return spkembs

def process_text(text):
    text = araby.strip_diacritics(text)
    return processor(tokenizer.encode(text)).reshape(1, -1)

net_input = {}

def inference(text, spkr=emb_path):
    net_input['src_tokens'] = process_text(text)
    net_input['spkembs'] = get_embs(spkr)
    outs, _, attn = task.generate_speech(
        [model],
        net_input,
    )
    with torch.no_grad():
        gen_audio = vocoder(outs.to(device))
    return (16000, gen_audio.cpu().numpy())

text_box = gr.Textbox(max_lines=2, label="Arabic Text")
out = gr.Audio(label="Synthesized Audio", type="numpy")
demo = gr.Interface(inference,
                    inputs=text_box, outputs=out, title="ArTST")

if __name__ == "__main__":
    demo.launch(share=True)
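For quick local checks outside the Gradio UI, the inference() helper above can also be called directly. A minimal sketch, assuming ckpts/clartts_tts.pt and embs/clartts.npy are present and that soundfile (not listed by this Space) is installed; the import does not launch the demo because launch() sits under the __main__ guard:

# Hypothetical smoke test for the Space's inference() helper.
import soundfile as sf          # assumption: installed separately
from app import inference

sr, audio = inference("السلام عليكم")            # returns (16000, numpy waveform)
sf.write("artst_tts_sample.wav", audio.squeeze(), sr)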
artst/__init__.py
ADDED
@@ -0,0 +1 @@
from . import data, tasks, criterions, models  # noqa
artst/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (218 Bytes).
artst/__pycache__/sequence_generator.cpython-38.pyc
ADDED
Binary file (26.2 kB).
artst/criterions/__init__.py
ADDED
@@ -0,0 +1,10 @@
import importlib
import os


for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        criterion_name = file[: file.find(".py")]
        importlib.import_module(
            "artst.criterions." + criterion_name
        )
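The loop above imports every non-underscore module in artst/criterions when the package is first imported, which is what makes the criteria registered via @register_criterion resolvable by name. A hedged sanity check, assuming fairseq exposes CRITERION_REGISTRY as recent releases do:

# Assumption: fairseq.criterions exports CRITERION_REGISTRY (true in recent fairseq).
import artst.criterions                       # runs the auto-import loop above
from fairseq.criterions import CRITERION_REGISTRY

assert "artst" in CRITERION_REGISTRY          # @register_criterion("artst") has fired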
artst/criterions/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (395 Bytes).
artst/criterions/__pycache__/artst_criterion.cpython-38.pyc
ADDED
Binary file (16.3 kB).
artst/criterions/__pycache__/speech_pretrain_criterion.cpython-38.pyc
ADDED
Binary file (9.08 kB).
artst/criterions/__pycache__/speech_to_text_loss.cpython-38.pyc
ADDED
Binary file (12.8 kB).
artst/criterions/__pycache__/speecht5_criterion.cpython-38.pyc
ADDED
Binary file (16.3 kB).
artst/criterions/__pycache__/text_pretrain_criterion.cpython-38.pyc
ADDED
Binary file (5.55 kB).
artst/criterions/__pycache__/text_to_speech_loss.cpython-38.pyc
ADDED
Binary file (14 kB).
artst/criterions/artst_criterion.py
ADDED
@@ -0,0 +1,443 @@
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transform (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import re
from dataclasses import dataclass

import math
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from artst.criterions.text_to_speech_loss import TexttoSpeechLoss
from artst.criterions.text_pretrain_criterion import TextPretrainCriterion, TextPretrainCriterionConfig
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterionConfig
from artst.criterions.speech_pretrain_criterion import SpeechPretrainCriterion, SpeechPretrainCriterionConfig
from artst.criterions.speech_to_text_loss import SpeechtoTextLoss, SpeechtoTextLossConfig
from fairseq.logging.meters import safe_round

@dataclass
class ArTSTCriterionConfig(
    LabelSmoothedCrossEntropyCriterionConfig,
    TextPretrainCriterionConfig,
    SpeechPretrainCriterionConfig,
    SpeechtoTextLossConfig
):
    pass

@register_criterion(
    "artst", dataclass=ArTSTCriterionConfig
)
class ArTSTCriterion(FairseqCriterion):
    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        pred_masked_weight,
        pred_nomask_weight,
        loss_weights=None,
        log_keys=None,
        ignore_prefix_size=0,
        report_accuracy=False,
        use_masking=True,
        use_weighted_masking=False,
        loss_type="L1",
        bce_pos_weight=5.0,
        bce_loss_lambda=1.0,
        use_guided_attn_loss=False,
        num_heads_applied_guided_attn=2,
        ce_weight=1.0,
        ctc_weight=0.0,
        hubert_weight=1.0,
        dec_weight=1.0,
        bart_weight=1.0,
    ):
        super().__init__(task)
        self.speech_criterion = TexttoSpeechLoss(
            task,
            sentence_avg,
            use_masking,
            use_weighted_masking,
            loss_type,
            bce_pos_weight,
            bce_loss_lambda,
            use_guided_attn_loss,
            num_heads_applied_guided_attn=num_heads_applied_guided_attn,
        )
        self.text_criterion = SpeechtoTextLoss(
            SpeechtoTextLossConfig,
            task,
            sentence_avg,
            label_smoothing,
            ignore_prefix_size,
            report_accuracy,
            ce_weight,
            ctc_weight
        )
        self.text_pretrain_criterion = TextPretrainCriterion(
            task,
            sentence_avg,
            bart_weight,
            loss_weights,
        )
        self.speech_pretrain_criterion = SpeechPretrainCriterion(
            task,
            sentence_avg,
            pred_masked_weight,
            pred_nomask_weight,
            loss_weights,
            log_keys,
            use_masking,
            use_weighted_masking,
            loss_type,
            bce_pos_weight,
            hubert_weight,
            dec_weight
        )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """

        task_name = sample['task_name']
        if task_name == 's2t' or task_name == 's2c':
            return self.text_criterion(model, sample, reduce)
        elif task_name == 't2s' or task_name == 's2s':
            return self.speech_criterion(model, sample)
        elif task_name == 'text_pretrain':
            return self.text_pretrain_criterion(model, sample, reduce)
        elif task_name == 'speech_pretrain':
            return self.speech_pretrain_criterion(model, sample, reduce)

    @classmethod
    def reduce_metrics(cls, logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        logging_outputs_dict = {}
        for logging_output in logging_outputs:
            for task_name in logging_output:
                if task_name not in ['s2t', 't2s', 's2c', 's2s', 'text_pretrain', 'speech_pretrain']:
                    continue

                if task_name not in logging_outputs_dict:
                    logging_outputs_dict[task_name] = []
                logging_outputs_dict[task_name].append(logging_output[task_name])

        for task_name in logging_outputs_dict:
            if task_name == 's2t':
                # LabelSmoothedCrossEntropyCriterion.reduce_metrics([logging_output['s2t'] for logging_output in logging_outputs])
                s2t_logging_output = logging_outputs_dict[task_name]
                # s2t_sum = sum(log.get("ce_loss", 0) for log in logging_outputs)
                loss_sum = sum(log.get("loss", 0) for log in s2t_logging_output)
                nll_loss_sum = sum(log.get("nll_loss", 0) for log in s2t_logging_output)
                ntokens = sum(log.get("ntokens", 0) for log in s2t_logging_output)
                ce_loss_sum = sum(log.get("ce_loss", 0) for log in s2t_logging_output)
                ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in s2t_logging_output)

                sample_size = max(1, sum(log.get("sample_size", 0) for log in s2t_logging_output))
                metrics.log_scalar(
                    "s2t_loss", loss_sum / sample_size / math.log(2), sample_size, 1, round=3
                )

                metrics.log_scalar(
                    "s2t_nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, 2, round=3
                )
                metrics.log_derived(
                    "s2t_ppl", lambda meters: utils.get_perplexity(meters["s2t_nll_loss"].avg, 2)
                )
                metrics.log_scalar(
                    "ctc_loss", ctc_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
                )
                metrics.log_scalar(
                    "ce_loss", ce_loss_sum / ntokens, ntokens, 2, round=3
                )

                total = utils.item(sum(log.get("total", 0) for log in s2t_logging_output))
                if total > 0:
                    metrics.log_scalar("s2t_total", total)
                    n_correct = utils.item(
                        sum(log.get("n_correct", 0) for log in s2t_logging_output)
                    )
                    metrics.log_scalar("s2t_n_correct", n_correct)
                    metrics.log_derived(
                        "s2t_accuracy",
                        lambda meters: round(
                            meters["s2t_n_correct"].sum * 100.0 / meters["s2t_total"].sum, 3
                        )
                        if meters["s2t_total"].sum > 0
                        else float("nan"),
                        2
                    )
                c_errors = sum(log.get("c_errors", 0) for log in s2t_logging_output)
                metrics.log_scalar("_c_errors", c_errors)
                c_total = sum(log.get("c_total", 0) for log in s2t_logging_output)
                metrics.log_scalar("_c_total", c_total)
                w_errors = sum(log.get("w_errors", 0) for log in s2t_logging_output)
                metrics.log_scalar("_w_errors", w_errors)
                wv_errors = sum(log.get("wv_errors", 0) for log in s2t_logging_output)
                metrics.log_scalar("_wv_errors", wv_errors)
                w_total = sum(log.get("w_total", 0) for log in s2t_logging_output)
                metrics.log_scalar("_w_total", w_total)
                if c_total > 0:
                    metrics.log_derived(
                        "uer",
                        lambda meters: safe_round(
                            meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
                        )
                        if meters["_c_total"].sum > 0
                        else float("nan"),
                    )
                if w_total > 0:
                    metrics.log_derived(
                        "wer",
                        lambda meters: safe_round(
                            meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                        )
                        if meters["_w_total"].sum > 0
                        else float("nan"),
                    )
                    metrics.log_derived(
                        "raw_wer",
                        lambda meters: safe_round(
                            meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
                        )
                        if meters["_w_total"].sum > 0
                        else float("nan"),
                    )

            if task_name == 't2s':
                # TTSLossCriterion.reduce_metrics([logging_output['t2s'] for logging_output in logging_outputs])
                # t2s_sum = sum(log.get("speech_loss", 0) for log in logging_outputs)
                t2s_logging_output = logging_outputs_dict[task_name]
                loss_sum = sum(log.get("loss", 0) for log in t2s_logging_output)
                l1_loss_sum = sum(log.get("l1_loss", 0) for log in t2s_logging_output)
                l2_loss_sum = sum(log.get("l2_loss", 0) for log in t2s_logging_output)
                bce_loss_sum = sum(log.get("bce_loss", 0) for log in t2s_logging_output)
                sample_size = max(1, sum(log.get("sample_size", 0) for log in t2s_logging_output))
                metrics.log_scalar(
                    "t2s_loss", loss_sum / sample_size, sample_size, 1, round=5
                )
                encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in t2s_logging_output)
                decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in t2s_logging_output)
                ngpu = sum(log.get("ngpu", 0) for log in t2s_logging_output)

                metrics.log_scalar(
                    "t2s_l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "t2s_l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "t2s_bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "t2s_encoder_alpha", encoder_alpha_sum / sample_size, sample_size, round=5
                )
                metrics.log_scalar(
                    "t2s_decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5
                )

                if "enc_dec_attn_loss" in t2s_logging_output[0]:
                    enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in t2s_logging_output)
                    metrics.log_scalar(
                        "t2s_enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8
                    )

            if task_name == 's2c':
                s2c_logging_output = logging_outputs_dict[task_name]
                loss_sum = sum(log.get("loss", 0) for log in s2c_logging_output)
                nll_loss_sum = sum(log.get("nll_loss", 0) for log in s2c_logging_output)
                ntokens = sum(log.get("ntokens", 0) for log in s2c_logging_output)

                sample_size = max(1, sum(log.get("sample_size", 0) for log in s2c_logging_output))
                metrics.log_scalar(
                    "s2c_loss", loss_sum / sample_size / math.log(2), sample_size, 1, round=3
                )

                metrics.log_scalar(
                    "s2c_nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, 2, round=3
                )

                total = utils.item(sum(log.get("total", 0) for log in s2c_logging_output))
                if total > 0:
                    metrics.log_scalar("s2c_total", total)
                    n_correct = utils.item(sum(log.get("n_correct", 0) for log in s2c_logging_output))
                    metrics.log_scalar("s2c_n_correct", n_correct)
                    metrics.log_derived(
                        "s2c_accuracy",
                        lambda meters: round(
                            meters["s2c_n_correct"].sum * 100.0 / meters["s2c_total"].sum, 3
                        )
                        if meters["s2c_total"].sum > 0
                        else float("nan"),
                        2
                    )

            if task_name == 's2s':
                s2s_logging_output = logging_outputs_dict[task_name]
                loss_sum = sum(log.get("loss", 0) for log in s2s_logging_output)
                l1_loss_sum = sum(log.get("l1_loss", 0) for log in s2s_logging_output)
                l2_loss_sum = sum(log.get("l2_loss", 0) for log in s2s_logging_output)
                bce_loss_sum = sum(log.get("bce_loss", 0) for log in s2s_logging_output)
                sample_size = max(1, sum(log.get("sample_size", 0) for log in s2s_logging_output))
                metrics.log_scalar(
                    "s2s_loss", loss_sum / sample_size, sample_size, 1, round=5
                )
                encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in s2s_logging_output)
                decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in s2s_logging_output)
                ngpu = sum(log.get("ngpu", 0) for log in s2s_logging_output)

                metrics.log_scalar(
                    "s2s_l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "s2s_l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "s2s_bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "s2s_decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5
                )

                if "enc_dec_attn_loss" in s2s_logging_output[0]:
                    enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in s2s_logging_output)
                    metrics.log_scalar(
                        "s2s_enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8
                    )

            if task_name == 'text_pretrain':
                bart_logging_output = logging_outputs_dict[task_name]
                loss_sum = sum(log.get("loss", 0) for log in bart_logging_output)
                ntokens = sum(log.get("ntokens", 0) for log in bart_logging_output)
                sample_size = max(1, sum(log.get("sample_size", 0) for log in bart_logging_output))
                bart_loss_sum = sum(log.get("bart_loss", 0) for log in bart_logging_output)

                # we divide by log(2) to convert the loss from base e to base 2
                metrics.log_scalar(
                    "text_loss", loss_sum / sample_size / math.log(2), sample_size, round=3
                )
                metrics.log_scalar(
                    "bart_loss", bart_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
                )
                if sample_size != ntokens:
                    metrics.log_scalar(
                        "bart_nll_loss", bart_loss_sum / ntokens / math.log(2), ntokens, round=3
                    )
                    metrics.log_derived(
                        "bart_ppl", lambda meters: utils.get_perplexity(meters["bart_nll_loss"].avg)
                    )
                else:
                    metrics.log_derived(
                        "bart_ppl", lambda meters: utils.get_perplexity(meters["bart_loss"].avg)
                    )
                metrics.log_scalar("bart_wpb", ntokens, priority=180, round=1)

                val_prob_perplexity = 0
                val_code_perplexity = 0
                sample_size_pp = 0
                count_log_cp = 0
                for log in bart_logging_output:
                    if "loss_prob_perplexity" in log:
                        val_prob_perplexity = val_prob_perplexity + log["loss_prob_perplexity"]
                        sample_size_pp = sample_size_pp + log["sample_size"]
                    if "code_perplexity" in log:
                        val_code_perplexity = val_code_perplexity + log["code_perplexity"]
                        count_log_cp = count_log_cp + 1
                if val_prob_perplexity > 0:
                    metrics.log_scalar("text_loss_prob_perplexity", val_prob_perplexity / sample_size_pp / math.log(2), round=3)
                if val_code_perplexity > 0:
                    metrics.log_scalar("text_code_perplexity", val_code_perplexity / count_log_cp, round=3)

            if task_name == 'speech_pretrain':
                hubert_logging_output = logging_outputs_dict[task_name]
                loss_sum = sum(log.get("loss", 0) for log in hubert_logging_output)
                ntokens = sum(log.get("ntokens", 0) for log in hubert_logging_output)
                sample_size = max(1, sum(log.get("sample_size", 0) for log in hubert_logging_output))
                dec_loss_sum = sum(log.get("dec_loss", 0) for log in hubert_logging_output)
                l1_loss_sum = sum(log.get("l1_loss", 0) for log in hubert_logging_output)
                l2_loss_sum = sum(log.get("l2_loss", 0) for log in hubert_logging_output)
                bce_loss_sum = sum(log.get("bce_loss", 0) for log in hubert_logging_output)
                ngpu = sum(log.get("ngpu", 0) for log in hubert_logging_output)

                metrics.log_scalar("hubert_loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
                if sample_size != ntokens:
                    metrics.log_scalar("hubert_nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
                    metrics.log_derived("hubert_ppl", lambda meters: utils.get_perplexity(meters["hubert_nll_loss"].avg))
                else:
                    metrics.log_derived("hubert_ppl", lambda meters: utils.get_perplexity(meters["hubert_loss"].avg))

                counts = {}
                for lk in hubert_logging_output[0].keys():
                    if lk.startswith("count_"):
                        val = sum(log[lk] for log in hubert_logging_output)
                        metrics.log_scalar("hubert_" + lk, val)
                        counts[lk] = val

                for lk in hubert_logging_output[0].keys():
                    if lk.startswith("loss_") and lk != 'loss_prob_perplexity':
                        val = sum(log[lk] for log in hubert_logging_output)
                        metrics.log_scalar("hubert_" + lk, val / sample_size / math.log(2), round=3)
                    elif lk.startswith("correct_"):
                        val = sum(log[lk] for log in hubert_logging_output)
                        metrics.log_scalar("hubert_" + lk, val / counts[re.sub("correct", "count", lk)])
                    # elif lk == 'code_perplexity':
                    #     val = sum(log[lk] for log in hubert_logging_output)
                    #     metrics.log_scalar("hubert_" + lk, val / len(hubert_logging_output), round=3)

                val_prob_perplexity = 0
                val_code_perplexity = 0
                sample_size_pp = 0
                count_log_cp = 0
                for log in hubert_logging_output:
                    if "loss_prob_perplexity" in log:
                        val_prob_perplexity = val_prob_perplexity + log["loss_prob_perplexity"]
                        sample_size_pp = sample_size_pp + log["sample_size"]
                    if "code_perplexity" in log:
                        val_code_perplexity = val_code_perplexity + log["code_perplexity"]
                        count_log_cp = count_log_cp + 1
                if val_prob_perplexity > 0:
                    metrics.log_scalar("hubert_loss_prob_perplexity", val_prob_perplexity / sample_size_pp / math.log(2), round=3)
                if val_code_perplexity > 0:
                    metrics.log_scalar("hubert_code_perplexity", val_code_perplexity / count_log_cp, round=3)

                metrics.log_scalar(
                    "hubert_dec_loss", dec_loss_sum / ngpu, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "hubert_l1_loss", l1_loss_sum / ngpu, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "hubert_l2_loss", l2_loss_sum / ngpu, sample_size, 2, round=5
                )
                metrics.log_scalar(
                    "hubert_bce_loss", bce_loss_sum / ngpu, sample_size, 2, round=5
                )
                if "enc_dec_attn_loss" in hubert_logging_output[0]:
                    enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in hubert_logging_output)
                    metrics.log_scalar(
                        "hubert_enc_dec_attn_loss", enc_dec_attn_loss_sum / ngpu, sample_size, round=8
                    )
                metrics.log_scalar("hubert_wpb", ntokens, priority=180, round=1)

        loss = sum(log.get("loss", 0) for log in logging_outputs)
        sample_size = max(1, sum(log.get("sample_size", 0) for log in logging_outputs))
        metrics.log_scalar(
            "loss", loss / sample_size, sample_size, 1, round=5
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False
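A side note on the recurring `/ math.log(2)` factor in reduce_metrics above: the per-task losses are accumulated in nats (natural log), and dividing by ln 2 converts them to bits before logging, as the inline comment in the text_pretrain branch states. A small worked example with an illustrative value:

import math

loss_nats = 2.0                      # e.g., an accumulated cross-entropy in nats
loss_bits = loss_nats / math.log(2)  # the value log_scalar would report
print(round(loss_bits, 3))           # -> 2.885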
artst/criterions/speech_pretrain_criterion.py
ADDED
@@ -0,0 +1,265 @@
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transform (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import math
import re
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion
from artst.criterions.text_to_speech_loss import TexttoSpeechLoss, TexttoSpeechLossConfig


@dataclass
class SpeechPretrainCriterionConfig(TexttoSpeechLossConfig):
    pred_masked_weight: float = field(
        default=1.0,
        metadata={"help": "weight for predictive loss for masked frames"},
    )
    pred_nomask_weight: float = field(
        default=0.0,
        metadata={"help": "weight for predictive loss for unmasked frames"},
    )
    loss_weights: Optional[List[float]] = field(
        default_factory=lambda: [10,],
        metadata={"help": "weights for additional loss terms (not first one)"},
    )
    log_keys: List[str] = field(
        default_factory=lambda: [],
        metadata={"help": "output keys to log"},
    )
    hubert_weight: float = field(
        default=1.0,
        metadata={"help": "weight of hubert loss"},
    )
    dec_weight: float = field(
        default=1.0,
        metadata={"help": "weight of decoder loss"},
    )


class SpeechPretrainCriterion(FairseqCriterion):
    def __init__(
        self,
        task,
        sentence_avg,
        pred_masked_weight,
        pred_nomask_weight,
        loss_weights=None,
        log_keys=None,
        use_masking=True,
        use_weighted_masking=False,
        loss_type="L1",
        bce_pos_weight=5.0,
        hubert_weight=1.0,
        dec_weight=1.0,
    ):
        super().__init__(task)
        self.pred_masked_weight = pred_masked_weight
        self.pred_nomask_weight = pred_nomask_weight
        self.loss_weights = loss_weights
        self.log_keys = [] if log_keys is None else log_keys
        self.hubert_weight = hubert_weight
        self.dec_weight = dec_weight

        self.speech_criterion = TexttoSpeechLoss(
            task,
            sentence_avg,
            use_masking,
            use_weighted_masking,
            loss_type,
            bce_pos_weight,
        )

    def forward(self, model, sample, reduce=True, log_pred=False):
        """Compute the loss for the given sample.
        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        if self.dec_weight == 0:
            sample["net_input"]["only_hubert"] = True
        net_output, net_output_dec = model(target_list=sample["target_list"], **sample["net_input"])
        loss = 0.
        sample_size = 0
        logging_output = {}
        reduction = "sum" if reduce else "none"

        loss_m_list = []
        logp_m_list = model.get_logits(net_output, True)
        targ_m_list = model.get_targets(None, net_output, True)
        assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
        for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
            loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
            loss_m_list.append(loss_m)
            logging_output[f"loss_m_{i}"] = loss_m.detach().item()
        if self.pred_masked_weight > 0:
            loss += self.pred_masked_weight * sum(loss_m_list)
            sample_size += targ_m_list[0].numel()

        loss_u_list = []
        logp_u_list = model.get_logits(net_output, False)
        targ_u_list = model.get_targets(None, net_output, False)
        assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
        for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
            loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
            loss_u_list.append(loss_u)
            logging_output[f"loss_u_{i}"] = loss_u.detach().item()
        if self.pred_nomask_weight > 0:
            loss += self.pred_nomask_weight * sum(loss_u_list)
            sample_size += targ_u_list[0].numel()

        if self.loss_weights is not None:
            assert hasattr(model, "get_extra_losses")
            extra_losses, names = model.get_extra_losses(net_output)
            if torch.is_tensor(extra_losses):
                extra_losses = [extra_losses]
                names = [names]
            if len(self.loss_weights) == 1 and len(extra_losses) != 1:
                self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
            if len(self.loss_weights) > len(extra_losses):
                modified_loss_weight = self.loss_weights[:len(extra_losses)]
            else:
                modified_loss_weight = self.loss_weights

            # assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
            for p, n, coef in zip(extra_losses, names, modified_loss_weight):
                # print(n + str(coef))
                if coef != 0 and p is not None:
                    p = coef * p.float() * sample_size
                    loss += p
                    logging_output[f"loss_{n}"] = p.detach().item()

        logging_output = {
            "ntokens": sample_size,
            "nsentences": sample["id"].numel(),
            "sample_size": sample_size,
            "ngpu": 1,
            **logging_output,
        }

        if 'loss_prob_perplexity' in logging_output:
            logging_output['code_perplexity'] = net_output['code_perplexity'].detach().item()

        for lk in self.log_keys:
            if lk in net_output:
                logging_output[lk] = float((net_output[lk].item()))

        def compute_correct(logits):
            if logits.numel() == 0:
                return 0, 0
            else:
                assert logits.dim() > 1, logits.shape
                max = logits.argmax(-1) == 0
                min = logits.argmin(-1) == 0
                both = max & min
                corr = max.long().sum().item() - both.long().sum().item()
                count = max.numel()
                return corr, count

        with torch.no_grad():
            for i, logp_m in enumerate(logp_m_list):
                corr_m, count_m = compute_correct(logp_m)
                logging_output[f"correct_m_{i}"] = corr_m
                logging_output[f"count_m_{i}"] = count_m

            for i, logp_u in enumerate(logp_u_list):
                corr_u, count_u = compute_correct(logp_u)
                logging_output[f"correct_u_{i}"] = corr_u
                logging_output[f"count_u_{i}"] = count_u

        if self.dec_weight == 0.0:
            logging_output["loss"] = loss.item() if reduce else loss
            return loss, sample_size, logging_output

        # ## dec loss
        dec_loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss = self.speech_criterion.compute_loss(model, net_output_dec, sample)

        # Log tts loss
        logging_output['dec_loss'] = dec_loss.item()
        logging_output['l1_loss'] = l1_loss.item()
        logging_output['l2_loss'] = l2_loss.item()
        logging_output['bce_loss'] = bce_loss.item()
        if enc_dec_attn_loss is not None:
            logging_output['enc_dec_attn_loss'] = enc_dec_attn_loss.item()

        loss = self.hubert_weight * loss + self.dec_weight * sample_size * dec_loss
        logging_output["loss"] = loss.item() if reduce else loss
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        dec_loss_sum = sum(log.get("dec_loss", 0) for log in logging_outputs)
        l1_loss_sum = sum(log.get("l1_loss", 0) for log in logging_outputs)
        l2_loss_sum = sum(log.get("l2_loss", 0) for log in logging_outputs)
        bce_loss_sum = sum(log.get("bce_loss", 0) for log in logging_outputs)
        ngpu = sum(log.get("ngpu", 0) for log in logging_outputs)

        metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
        if sample_size != ntokens:
            metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
        else:
            metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))

        counts = {}
        for lk in logging_outputs[0].keys():
            if lk.startswith("count_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val)
                counts[lk] = val

        for lk in logging_outputs[0].keys():
            if lk.startswith("loss_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
            elif lk.startswith("correct_"):
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
            elif lk == 'code_perplexity':
                val = sum(log[lk] for log in logging_outputs)
                metrics.log_scalar(lk, val / len(logging_outputs), round=3)

        metrics.log_scalar(
            "dec_loss", dec_loss_sum / ngpu, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "l1_loss", l1_loss_sum / ngpu, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "l2_loss", l2_loss_sum / ngpu, sample_size, 2, round=5
        )
        metrics.log_scalar(
            "bce_loss", bce_loss_sum / ngpu, sample_size, 2, round=5
        )
        if "enc_dec_attn_loss" in logging_outputs[0]:
            enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in logging_outputs)
            metrics.log_scalar(
                "enc_dec_attn_loss", enc_dec_attn_loss_sum / ngpu, sample_size, round=8
            )

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        raise NotImplementedError()

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improve distributed training speed.
        """
        return False
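The nested compute_correct helper above treats index 0 of each logit row as the positive (target) entry and discards degenerate rows where argmax and argmin coincide at 0. A toy check of that convention, with made-up logits (the third-row behaviour assumes ties resolve to the first index, as PyTorch does on CPU):

import torch

logits = torch.tensor([[2.0, 0.5, 0.1],   # argmax == 0 -> counted as correct
                       [0.1, 3.0, 0.2],   # argmax != 0 -> incorrect
                       [1.0, 1.0, 1.0]])  # argmax == argmin == 0 -> excluded
is_max = logits.argmax(-1) == 0
is_min = logits.argmin(-1) == 0
both = is_max & is_min
corr = is_max.long().sum().item() - both.long().sum().item()
print(corr, is_max.numel())               # -> 1 3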
artst/criterions/speech_to_text_loss.py
ADDED
@@ -0,0 +1,473 @@
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transform (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------

import math
from argparse import Namespace
from dataclasses import dataclass, field
from omegaconf import II
from typing import Optional

import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import post_process
from fairseq.tasks import FairseqTask
from fairseq.logging.meters import safe_round

import logging
logger = logging.getLogger(__name__)

@dataclass
class SpeechtoTextLossConfig(FairseqDataclass):
    zero_infinity: bool = field(
        default=False,
        metadata={"help": "zero inf loss when source length <= target length"},
    )
    sentence_avg: bool = II("optimization.sentence_avg")
    post_process: Optional[str] = field(
        default="sentencepiece",
        metadata={
            "help": "how to post process predictions into words. can be letter, "
            "wordpiece, BPE symbols, etc. "
            "See fairseq.data.data_utils.post_process() for full list of options"
        },
    )
    wer_kenlm_model: Optional[str] = field(
        default=None,
        metadata={
            "help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
        },
    )
    wer_lexicon: Optional[str] = field(
        default=None,
        metadata={"help": "lexicon to use with wer_kenlm_model"},
    )
    wer_lm_weight: float = field(
        default=2.0,
        metadata={"help": "lm weight to use with wer_kenlm_model"},
    )
    wer_word_score: float = field(
        default=-1.0,
        metadata={"help": "lm word score to use with wer_kenlm_model"},
    )

    wer_args: Optional[str] = field(
        default=None,
        metadata={
            "help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
        },
    )

    label_smoothing: float = field(
        default=0.0,
        metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
    )
    report_accuracy: bool = field(
        default=False,
        metadata={"help": "report accuracy metric"},
    )
    ignore_prefix_size: int = field(
        default=0,
        metadata={"help": "Ignore first N tokens"},
    )
    #: bool = II("optimization.sentence_avg")

    ce_weight: float = field(
        default=1.0,
        metadata={"help": "loss weight for cross entropy"},
    )
    ctc_weight: float = field(
        default=0.0,
        metadata={"help": "loss weight for ctc in ASR"},
    )


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(-1)
    nll_loss = -lprobs.gather(dim=-1, index=target)
    smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        nll_loss.masked_fill_(pad_mask, 0.0)
        smooth_loss.masked_fill_(pad_mask, 0.0)
    else:
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
    if reduce:
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
    eps_i = epsilon / (lprobs.size(-1) - 1)
    loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
    return loss, nll_loss


class SpeechtoTextLoss(FairseqCriterion):
    def __init__(
        self,
        cfg: SpeechtoTextLossConfig,
        task: FairseqTask,
        sentence_avg=True,
        label_smoothing=0.1,
        ignore_prefix_size=0,
        report_accuracy=False,
        ce_weight=1.0,
        ctc_weight=0.0,
    ):

        super().__init__(task)
        self.blank_idx = (
            task.target_dictionary.index(task.blank_symbol)
            if hasattr(task, "blank_symbol")
            else 0
        )
        # print("self.blank_idx: ", self.blank_idx)

        self.pad_idx = task.target_dictionary.pad()
        self.eos_idx = task.target_dictionary.eos()
        self.post_process = cfg.post_process
        self.ce_weight = ce_weight
        self.ctc_weight = ctc_weight

        ## for ce
        self.sentence_avg = sentence_avg
        self.eps = label_smoothing
        self.ignore_prefix_size = ignore_prefix_size
        self.report_accuracy = report_accuracy

        if cfg.wer_args is not None:
            (
                cfg.wer_kenlm_model,
                cfg.wer_lexicon,
                cfg.wer_lm_weight,
                cfg.wer_word_score,
            ) = eval(cfg.wer_args)

        if cfg.wer_kenlm_model is not None:
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            dec_args = Namespace()
            dec_args.nbest = 1
            dec_args.criterion = "ctc"
            dec_args.kenlm_model = cfg.wer_kenlm_model
            dec_args.lexicon = cfg.wer_lexicon
            dec_args.beam = 50
            dec_args.beam_size_token = min(50, len(task.target_dictionary))
            dec_args.beam_threshold = min(50, len(task.target_dictionary))
            dec_args.lm_weight = cfg.wer_lm_weight
            dec_args.word_score = cfg.wer_word_score
            dec_args.unk_weight = -math.inf
            dec_args.sil_weight = 0

            self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
        else:
            self.w2l_decoder = None

        self.zero_infinity = cfg.zero_infinity
        # self.sentence_avg = cfg.sentence_avg

        if self.ce_weight > 0 and self.ctc_weight > 0:
            logger.info("Using cross entropy loss and CTC loss for ASR")
        elif self.ce_weight > 0:
            logger.info("Only using CE loss")
        elif self.ctc_weight > 0:
            logger.info("Only using CTC loss for ASR")
        else:
            logger.info("ERROR")

    def forward(self, model, sample, reduce=True):

        if self.ce_weight == 0 and self.ctc_weight > 0:
            sample["only_ctc"] = True

        net_output_decoder, net_output = model(**sample["net_input"])

        if self.ce_weight > 0:
            loss_ce, nll_loss_ce = self.compute_loss(model, net_output_decoder, sample, reduce=reduce)
            # print("loss_ce: ", loss_ce)
        else:
            nll_loss_ce = None

        if self.ctc_weight > 0:
            loss_ctc, lprobs, input_lengths = self.compute_loss_ctc(model, net_output, sample)

        if self.ce_weight > 0 and self.ctc_weight > 0:
            loss = self.ce_weight * loss_ce + self.ctc_weight * loss_ctc
        elif self.ce_weight > 0:
            loss = loss_ce
        elif self.ctc_weight > 0:
            loss = loss_ctc
        else:
            logger.info("ERROR: must ce_weight > 0 or ctc_weight > 0")

        ntokens = (
            sample["ntokens"] if "ntokens" in sample else sample["target_lengths"].sum().item()
        )

        sample_size = sample["target"].size(0) if self.sentence_avg else ntokens

        logging_output = {
            "loss": loss.item(),
            "ce_loss": loss_ce.item() if self.ce_weight > 0 else 0,
            "ctc_loss": loss_ctc.item() if self.ctc_weight > 0 else 0,
            "nll_loss": nll_loss_ce.item() if nll_loss_ce is not None else 0,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }

        if self.ce_weight > 0 and self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, net_output_decoder, sample)
            logging_output["n_correct"] = utils.item(n_correct.item())
            logging_output["total"] = utils.item(total.data)

        if self.ctc_weight > 0 and not model.training:
            import editdistance

            with torch.no_grad():
                lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()

                c_err = 0
                c_len = 0
                w_errs = 0
                w_len = 0
                wv_errs = 0
                for lp, t, inp_l in zip(
                    lprobs_t,
                    sample["target_label"]
                    if "target_label" in sample
                    else sample["target"],
                    input_lengths,
                ):
                    lp = lp[:inp_l].unsqueeze(0)

                    decoded = None
                    if self.w2l_decoder is not None:
                        decoded = self.w2l_decoder.decode(lp)
                        if len(decoded) < 1:
                            decoded = None
                        else:
                            decoded = decoded[0]
                            if len(decoded) < 1:
                                decoded = None
                            else:
                                decoded = decoded[0]

                    p = (t != self.task.target_dictionary.pad()) & (
                        t != self.task.target_dictionary.eos()
                    )
                    targ = t[p]
                    targ_units = self.task.target_dictionary.string(targ)
                    targ_units_arr = targ.tolist()

                    toks = lp.argmax(dim=-1).unique_consecutive()
                    pred_units_arr = toks[toks != self.blank_idx].tolist()

                    c_err += editdistance.eval(pred_units_arr, targ_units_arr)
                    c_len += len(targ_units_arr)

                    targ_words = post_process(targ_units, self.post_process).split()

                    pred_units = self.task.target_dictionary.string(pred_units_arr)
                    pred_words_raw = post_process(pred_units, self.post_process).split()

                    if decoded is not None and "words" in decoded:
                        pred_words = decoded["words"]
                        w_errs += editdistance.eval(pred_words, targ_words)
                        wv_errs += editdistance.eval(pred_words_raw, targ_words)
                    else:
                        dist = editdistance.eval(pred_words_raw, targ_words)
                        w_errs += dist
                        wv_errs += dist

                    w_len += len(targ_words)

                logging_output["wv_errors"] = wv_errs
                logging_output["w_errors"] = w_errs
                logging_output["w_total"] = w_len
                logging_output["c_errors"] = c_err
                logging_output["c_total"] = c_len

        return loss, sample_size, logging_output

    def compute_loss_ctc(self, model, net_output, sample):
        lprobs = model.get_normalized_probs_for_ctc(
            net_output, log_probs=True
        ).contiguous()  # (T, B, C) from the encoder

        if net_output["encoder_padding_mask"] is not None:
            non_padding_mask = ~net_output["encoder_padding_mask"][0]
            input_lengths = non_padding_mask.long().sum(-1)
        else:
            input_lengths = lprobs.new_full(
                (lprobs.size(1),), lprobs.size(0), dtype=torch.long
            )

        pad_mask = (sample["target"] != self.pad_idx) & (
            sample["target"] != self.eos_idx
        )
        targets_flat = sample["target"].masked_select(pad_mask)
        if "target_lengths" in sample:
            target_lengths = sample["target_lengths"]
        else:
            target_lengths = pad_mask.sum(-1)

        ## processing
        target_lengths = target_lengths - 1

        with torch.backends.cudnn.flags(enabled=False):
            loss_ctc = F.ctc_loss(
                lprobs,
                targets_flat,
                input_lengths,
                target_lengths,
                blank=self.blank_idx,
                reduction="sum",
                zero_infinity=True,
            )

        return loss_ctc, lprobs, input_lengths

    ## for ce
    def get_lprobs_and_target(self, model, net_output, sample):
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        target = model.get_targets(sample, net_output)
        if self.ignore_prefix_size > 0:
            if getattr(lprobs, "batch_first", False):
                lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
                target = target[:, self.ignore_prefix_size :].contiguous()
            else:
                lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
                target = target[self.ignore_prefix_size :, :].contiguous()
        return lprobs.view(-1, lprobs.size(-1)), target.view(-1)

    def compute_loss(self, model, net_output, sample, reduce=True):
        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
        loss, nll_loss = label_smoothed_nll_loss(
            lprobs,
            target,
            self.eps,
            ignore_index=self.padding_idx,
            reduce=reduce,
        )
        return loss, nll_loss

    def compute_accuracy(self, model, net_output, sample):
        lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
        mask = target.ne(self.padding_idx)
        n_correct = torch.sum(
            lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
        )
        total = torch.sum(mask)
        return n_correct, total


    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""

        loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
        nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
        ce_loss_sum = sum(log.get("ce_loss", 0) for log in logging_outputs)
        ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs)
        ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs)
        )
        sample_size = utils.item(
            sum(log.get("sample_size", 0) for log in logging_outputs)
        )

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )

        metrics.log_scalar(
            "ctc_loss", ctc_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
        )
        metrics.log_scalar(
            "ce_loss", ce_loss_sum / ntokens, ntokens, 2, round=3
        )
        metrics.log_scalar(
            "nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, 2, round=3
        )
        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg, 2)
        )

        total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
        if total > 0:
            metrics.log_scalar("total", total)
            n_correct = utils.item(
                sum(log.get("n_correct", 0) for log in logging_outputs)
            )
            metrics.log_scalar("n_correct", n_correct)
            metrics.log_derived(
                "accuracy",
                lambda meters: round(
                    meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
                )
                if meters["total"].sum > 0
                else float("nan"),
                2
            )

        metrics.log_scalar("ntokens", ntokens)
        metrics.log_scalar("nsentences", nsentences)
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )

        c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
        metrics.log_scalar("_c_errors", c_errors)
        c_total = sum(log.get("c_total", 0) for log in logging_outputs)
        metrics.log_scalar("_c_total", c_total)
        w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
|
433 |
+
metrics.log_scalar("_w_errors", w_errors)
|
434 |
+
wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
|
435 |
+
metrics.log_scalar("_wv_errors", wv_errors)
|
436 |
+
w_total = sum(log.get("w_total", 0) for log in logging_outputs)
|
437 |
+
metrics.log_scalar("_w_total", w_total)
|
438 |
+
|
439 |
+
if c_total > 0:
|
440 |
+
metrics.log_derived(
|
441 |
+
"uer",
|
442 |
+
lambda meters: safe_round(
|
443 |
+
meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
|
444 |
+
)
|
445 |
+
if meters["_c_total"].sum > 0
|
446 |
+
else float("nan"),
|
447 |
+
)
|
448 |
+
if w_total > 0:
|
449 |
+
metrics.log_derived(
|
450 |
+
"wer",
|
451 |
+
lambda meters: safe_round(
|
452 |
+
meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
|
453 |
+
)
|
454 |
+
if meters["_w_total"].sum > 0
|
455 |
+
else float("nan"),
|
456 |
+
)
|
457 |
+
metrics.log_derived(
|
458 |
+
"raw_wer",
|
459 |
+
lambda meters: safe_round(
|
460 |
+
meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
|
461 |
+
)
|
462 |
+
if meters["_w_total"].sum > 0
|
463 |
+
else float("nan"),
|
464 |
+
)
|
465 |
+
|
466 |
+
@staticmethod
|
467 |
+
def logging_outputs_can_be_summed() -> bool:
|
468 |
+
"""
|
469 |
+
Whether the logging outputs returned by `forward` can be summed
|
470 |
+
across workers prior to calling `reduce_metrics`. Setting this
|
471 |
+
to True will improve distributed training speed.
|
472 |
+
"""
|
473 |
+
return True
|
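As a quick illustration of the CTC branch in compute_loss_ctc above, here is a standalone sketch (not part of the upload; the blank index, shapes, and dummy tensors are assumptions) that feeds (T, B, C) log-probabilities, flattened targets, and per-example lengths into F.ctc_loss with the same reduction="sum" and zero_infinity=True settings:

import torch
import torch.nn.functional as F

T, B, C = 50, 2, 10                                      # time steps, batch size, vocab size (dummy)
blank_idx = 0                                            # assumed blank index for this sketch
lprobs = torch.randn(T, B, C).log_softmax(dim=-1)        # (T, B, C) encoder log-probs
input_lengths = torch.full((B,), T, dtype=torch.long)    # every frame is valid here
targets = torch.randint(1, C, (B, 12))                   # dummy label ids, blank excluded
target_lengths = torch.full((B,), 12, dtype=torch.long)
targets_flat = targets.flatten()                         # concatenation of all target sequences

loss_ctc = F.ctc_loss(
    lprobs, targets_flat, input_lengths, target_lengths,
    blank=blank_idx, reduction="sum", zero_infinity=True,
)
print(loss_ctc.item())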
artst/criterions/text_pretrain_criterion.py
ADDED
@@ -0,0 +1,142 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
# Based on speecht5, fairseq and espnet code bases
|
5 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import math
|
9 |
+
from dataclasses import dataclass, field
|
10 |
+
from typing import List, Optional
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from fairseq import metrics, utils
|
15 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
16 |
+
from fairseq.dataclass import FairseqDataclass
|
17 |
+
from omegaconf import II
|
18 |
+
|
19 |
+
|
20 |
+
@dataclass
|
21 |
+
class TextPretrainCriterionConfig(FairseqDataclass):
|
22 |
+
sentence_avg: bool = II("optimization.sentence_avg")
|
23 |
+
loss_weights: Optional[List[float]] = field(
|
24 |
+
default_factory=lambda: [0.1,],
|
25 |
+
metadata={"help": "weights for additional loss terms (not first one)"},
|
26 |
+
)
|
27 |
+
bart_weight: float = field(
|
28 |
+
default=1.0,
|
29 |
+
metadata={"help": "loss weight for cross entropy"},
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
class TextPretrainCriterion(FairseqCriterion):
|
34 |
+
def __init__(self, task, sentence_avg, bart_weight, loss_weights=None):
|
35 |
+
super().__init__(task)
|
36 |
+
self.sentence_avg = sentence_avg
|
37 |
+
self.loss_weights = loss_weights
|
38 |
+
self.bart_weight = bart_weight
|
39 |
+
|
40 |
+
def forward(self, model, sample, reduce=True):
|
41 |
+
"""Compute the loss for the given sample.
|
42 |
+
|
43 |
+
Returns a tuple with three elements:
|
44 |
+
1) the loss
|
45 |
+
2) the sample size, which is used as the denominator for the gradient
|
46 |
+
3) logging outputs to display while training
|
47 |
+
"""
|
48 |
+
net_output, codebook_out, encoder_output = model(**sample["net_input"])
|
49 |
+
bart_loss, _ = self.compute_loss(model, net_output, sample, reduce=reduce)
|
50 |
+
sample_size = (
|
51 |
+
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
|
52 |
+
)
|
53 |
+
|
54 |
+
loss = self.bart_weight * bart_loss
|
55 |
+
logging_output = {
|
56 |
+
"loss": loss.item(),
|
57 |
+
"ntokens": sample["ntokens"],
|
58 |
+
"nsentences": sample["target"].size(0),
|
59 |
+
"bart_loss": bart_loss.item(),
|
60 |
+
"sample_size": sample_size,
|
61 |
+
}
|
62 |
+
|
63 |
+
if "prob_perplexity" in codebook_out:
|
64 |
+
assert hasattr(model, "get_extra_losses")
|
65 |
+
extra_losses, names = model.get_extra_losses(codebook_out)
|
66 |
+
if torch.is_tensor(extra_losses):
|
67 |
+
extra_losses = [extra_losses]
|
68 |
+
names = [names]
|
69 |
+
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
|
70 |
+
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
|
71 |
+
if len(self.loss_weights) > len(extra_losses):
|
72 |
+
modified_loss_weight = self.loss_weights[len(extra_losses):]
|
73 |
+
else:
|
74 |
+
modified_loss_weight = self.loss_weights
|
75 |
+
|
76 |
+
# assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}"
|
77 |
+
for p, n, coef in zip(extra_losses, names, modified_loss_weight):
|
78 |
+
# print(n + str(coef))
|
79 |
+
if coef != 0 and p is not None:
|
80 |
+
p = coef * p.float() * sample_size
|
81 |
+
loss += p
|
82 |
+
logging_output[f"loss_{n}"] = p.item()
|
83 |
+
|
84 |
+
if 'loss_prob_perplexity' in logging_output:
|
85 |
+
logging_output['code_perplexity'] = codebook_out['code_perplexity'].item()
|
86 |
+
|
87 |
+
return loss, sample_size, logging_output
|
88 |
+
|
89 |
+
def compute_loss(self, model, net_output, sample, reduce=True):
|
90 |
+
lprobs = model.get_normalized_probs(net_output, log_probs=True)
|
91 |
+
lprobs = lprobs.view(-1, lprobs.size(-1))
|
92 |
+
target = model.get_targets(sample, net_output).view(-1)
|
93 |
+
loss = F.nll_loss(
|
94 |
+
lprobs,
|
95 |
+
target,
|
96 |
+
ignore_index=self.padding_idx,
|
97 |
+
reduction="sum" if reduce else "none",
|
98 |
+
)
|
99 |
+
return loss, loss
|
100 |
+
|
101 |
+
@staticmethod
|
102 |
+
def reduce_metrics(logging_outputs) -> None:
|
103 |
+
"""Aggregate logging outputs from data parallel training."""
|
104 |
+
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
|
105 |
+
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
|
106 |
+
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
|
107 |
+
bart_loss_sum = sum(log.get("bart_loss", 0) for log in logging_outputs)
|
108 |
+
|
109 |
+
# we divide by log(2) to convert the loss from base e to base 2
|
110 |
+
metrics.log_scalar(
|
111 |
+
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
|
112 |
+
)
|
113 |
+
metrics.log_scalar(
|
114 |
+
"bart_loss", bart_loss_sum / sample_size / math.log(2), ntokens, 2, round=3
|
115 |
+
)
|
116 |
+
if sample_size != ntokens:
|
117 |
+
metrics.log_scalar(
|
118 |
+
"nll_loss", bart_loss_sum / ntokens / math.log(2), ntokens, round=3
|
119 |
+
)
|
120 |
+
metrics.log_derived(
|
121 |
+
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
|
122 |
+
)
|
123 |
+
else:
|
124 |
+
metrics.log_derived(
|
125 |
+
"ppl", lambda meters: utils.get_perplexity(meters["bart_loss"].avg)
|
126 |
+
)
|
127 |
+
|
128 |
+
if "loss_prob_perplexity" in logging_outputs[0].keys():
|
129 |
+
val = sum(log["loss_prob_perplexity"] for log in logging_outputs)
|
130 |
+
metrics.log_scalar("loss_prob_perplexity", val / sample_size / math.log(2), round=3)
|
131 |
+
if "code_perplexity" in logging_outputs[0].keys():
|
132 |
+
val = sum(log["code_perplexity"] for log in logging_outputs)
|
133 |
+
metrics.log_scalar("code_perplexity", val / len(logging_outputs), round=3)
|
134 |
+
|
135 |
+
@staticmethod
|
136 |
+
def logging_outputs_can_be_summed() -> bool:
|
137 |
+
"""
|
138 |
+
Whether the logging outputs returned by `forward` can be summed
|
139 |
+
across workers prior to calling `reduce_metrics`. Setting this
|
140 |
+
to True will improve distributed training speed.
|
141 |
+
"""
|
142 |
+
return True
|
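For reference, reduce_metrics above reports losses in bits per token: the summed cross-entropy (in nats) is divided by the sample size and by log(2), and perplexity is 2 raised to that average. A tiny sketch with made-up numbers (illustration only, mirroring the conversion rather than calling fairseq):

import math

loss_sum = 5234.7   # hypothetical summed NLL over a batch, in nats
ntokens = 2048      # hypothetical token count (the sample_size when sentence_avg is False)

loss_bits = loss_sum / ntokens / math.log(2)   # per-token loss in bits, as logged above
ppl = 2 ** loss_bits                           # what utils.get_perplexity reports
print(f"loss={loss_bits:.3f} bits/token, ppl={ppl:.2f}")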
artst/criterions/text_to_speech_loss.py
ADDED
@@ -0,0 +1,425 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
# Based on speecht5, fairseq and espnet code bases
|
5 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
from dataclasses import dataclass, field
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from fairseq import metrics, utils
|
12 |
+
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
|
13 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
14 |
+
from fairseq.dataclass import FairseqDataclass
|
15 |
+
from artst.models.modules.speech_encoder_prenet import SpeechEncoderPrenet
|
16 |
+
from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import GuidedAttentionLoss
|
17 |
+
from omegaconf import II
|
18 |
+
from typing import Any
|
19 |
+
|
20 |
+
|
21 |
+
@dataclass
|
22 |
+
class TexttoSpeechLossConfig(FairseqDataclass):
|
23 |
+
use_masking: bool = field(
|
24 |
+
default=True,
|
25 |
+
metadata={"help": "Whether to use masking in calculation of loss"},
|
26 |
+
)
|
27 |
+
use_weighted_masking: bool = field(
|
28 |
+
default=False,
|
29 |
+
metadata={"help": "Whether to use weighted masking in calculation of loss"},
|
30 |
+
)
|
31 |
+
loss_type: str = field(
|
32 |
+
default="L1",
|
33 |
+
metadata={"help": "How to calc loss"},
|
34 |
+
)
|
35 |
+
bce_pos_weight: float = field(
|
36 |
+
default=5.0,
|
37 |
+
metadata={"help": "Positive sample weight in BCE calculation (only for use-masking=True)"},
|
38 |
+
)
|
39 |
+
bce_loss_lambda: float = field(
|
40 |
+
default=1.0,
|
41 |
+
metadata={"help": "Lambda in bce loss"},
|
42 |
+
)
|
43 |
+
use_guided_attn_loss: bool = field(
|
44 |
+
default=False,
|
45 |
+
metadata={"help": "Whether to use guided attention loss"},
|
46 |
+
)
|
47 |
+
guided_attn_loss_sigma: float = field(
|
48 |
+
default=0.4,
|
49 |
+
metadata={"help": "Sigma in guided attention loss"},
|
50 |
+
)
|
51 |
+
guided_attn_loss_lambda: float = field(
|
52 |
+
default=10.0,
|
53 |
+
metadata={"help": "Lambda in guided attention loss"},
|
54 |
+
)
|
55 |
+
num_layers_applied_guided_attn: int = field(
|
56 |
+
default=2,
|
57 |
+
metadata={"help": "Number of layers to be applied guided attention loss, if set -1, all of the layers will be applied."},
|
58 |
+
)
|
59 |
+
num_heads_applied_guided_attn: int = field(
|
60 |
+
default=2,
|
61 |
+
metadata={"help": "Number of heads in each layer to be applied guided attention loss, if set -1, all of the heads will be applied."},
|
62 |
+
)
|
63 |
+
modules_applied_guided_attn: Any = field(
|
64 |
+
default=("encoder-decoder",),
|
65 |
+
metadata={"help": "Module name list to be applied guided attention loss"},
|
66 |
+
)
|
67 |
+
sentence_avg: bool = II("optimization.sentence_avg")
|
68 |
+
|
69 |
+
|
70 |
+
class TexttoSpeechLoss(FairseqCriterion):
|
71 |
+
def __init__(
|
72 |
+
self,
|
73 |
+
task,
|
74 |
+
sentence_avg,
|
75 |
+
use_masking=True,
|
76 |
+
use_weighted_masking=False,
|
77 |
+
loss_type="L1",
|
78 |
+
bce_pos_weight=5.0,
|
79 |
+
bce_loss_lambda=1.0,
|
80 |
+
use_guided_attn_loss=False,
|
81 |
+
guided_attn_loss_sigma=0.4,
|
82 |
+
guided_attn_loss_lambda=1.0,
|
83 |
+
num_layers_applied_guided_attn=2,
|
84 |
+
num_heads_applied_guided_attn=2,
|
85 |
+
modules_applied_guided_attn=["encoder-decoder"],
|
86 |
+
):
|
87 |
+
super().__init__(task)
|
88 |
+
self.sentence_avg = sentence_avg
|
89 |
+
self.use_masking = use_masking
|
90 |
+
self.use_weighted_masking = use_weighted_masking
|
91 |
+
self.loss_type = loss_type
|
92 |
+
self.bce_pos_weight = bce_pos_weight
|
93 |
+
self.bce_loss_lambda = bce_loss_lambda
|
94 |
+
self.use_guided_attn_loss = use_guided_attn_loss
|
95 |
+
self.guided_attn_loss_sigma = guided_attn_loss_sigma
|
96 |
+
self.guided_attn_loss_lambda = guided_attn_loss_lambda
|
97 |
+
# define loss function
|
98 |
+
self.criterion = Tacotron2Loss(
|
99 |
+
use_masking=use_masking,
|
100 |
+
use_weighted_masking=use_weighted_masking,
|
101 |
+
bce_pos_weight=bce_pos_weight,
|
102 |
+
)
|
103 |
+
if self.use_guided_attn_loss:
|
104 |
+
self.num_layers_applied_guided_attn = num_layers_applied_guided_attn
|
105 |
+
self.num_heads_applied_guided_attn = num_heads_applied_guided_attn
|
106 |
+
self.modules_applied_guided_attn = modules_applied_guided_attn
|
107 |
+
if self.use_guided_attn_loss:
|
108 |
+
self.attn_criterion = GuidedMultiHeadAttentionLoss(
|
109 |
+
sigma=guided_attn_loss_sigma,
|
110 |
+
alpha=guided_attn_loss_lambda,
|
111 |
+
)
|
112 |
+
|
113 |
+
def forward(self, model, sample):
|
114 |
+
"""Compute the loss for the given sample.
|
115 |
+
|
116 |
+
Returns a tuple with three elements:
|
117 |
+
1) the loss
|
118 |
+
2) the sample size, which is used as the denominator for the gradient
|
119 |
+
3) logging outputs to display while training
|
120 |
+
"""
|
121 |
+
net_output = model(**sample["net_input"])
|
122 |
+
loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss = self.compute_loss(model, net_output, sample)
|
123 |
+
# sample_size = (
|
124 |
+
# sample["target"].size(0) if self.sentence_avg else sample["nframes"]
|
125 |
+
# )
|
126 |
+
sample_size = 1
|
127 |
+
logging_output = {
|
128 |
+
"loss": loss.item(),
|
129 |
+
"l1_loss": l1_loss.item(),
|
130 |
+
"l2_loss": l2_loss.item(),
|
131 |
+
"bce_loss": bce_loss.item(),
|
132 |
+
"sample_size": 1,
|
133 |
+
"ntokens": sample["ntokens"],
|
134 |
+
"nsentences": sample["target"].size(0),
|
135 |
+
}
|
136 |
+
|
137 |
+
if enc_dec_attn_loss is not None:
|
138 |
+
logging_output['enc_dec_attn_loss'] = enc_dec_attn_loss.item()
|
139 |
+
|
140 |
+
if hasattr(model, 'text_encoder_prenet'):
|
141 |
+
logging_output["encoder_alpha"] = model.text_encoder_prenet.encoder_prenet[-1].alpha.item()
|
142 |
+
logging_output["decoder_alpha"] = model.speech_decoder_prenet.decoder_prenet[-1].alpha.item()
|
143 |
+
elif hasattr(model, "speech_encoder_prenet"):
|
144 |
+
logging_output["decoder_alpha"] = model.speech_decoder_prenet.decoder_prenet[-1].alpha.item()
|
145 |
+
else:
|
146 |
+
if 'task' not in sample:
|
147 |
+
logging_output["encoder_alpha"] = model.encoder_prenet.encoder_prenet[-1].alpha.item()
|
148 |
+
logging_output["decoder_alpha"] = model.decoder_prenet.decoder_prenet[-1].alpha.item()
|
149 |
+
|
150 |
+
return loss, sample_size, logging_output
|
151 |
+
|
152 |
+
def compute_loss(self, model, net_output, sample):
|
153 |
+
before_outs, after_outs, logits, attn = net_output
|
154 |
+
labels = sample["labels"]
|
155 |
+
ys = sample["dec_target"]
|
156 |
+
olens = sample["dec_target_lengths"]
|
157 |
+
ilens = sample["src_lengths"]
|
158 |
+
|
159 |
+
# modify the mod part of the ground truth (trim targets to a multiple of the reduction factor)
|
160 |
+
if model.reduction_factor > 1:
|
161 |
+
olens_in = olens.new([torch.div(olen, model.reduction_factor, rounding_mode='floor') for olen in olens])
|
162 |
+
olens = olens.new([olen - olen % model.reduction_factor for olen in olens])
|
163 |
+
max_olen = max(olens)
|
164 |
+
ys = ys[:, :max_olen]
|
165 |
+
labels = labels[:, :max_olen]
|
166 |
+
labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1), 1.0) # make sure at least one frame has 1
|
167 |
+
# labels[:, -1] = 1.0
|
168 |
+
else:
|
169 |
+
olens_in = olens
|
170 |
+
|
171 |
+
# calculate loss values
|
172 |
+
l1_loss, l2_loss, bce_loss = self.criterion(
|
173 |
+
after_outs, before_outs, logits, ys, labels, olens
|
174 |
+
)
|
175 |
+
|
176 |
+
# l1_loss = l1_loss / ys.size(2)
|
177 |
+
# l2_loss = l2_loss / ys.size(2)
|
178 |
+
|
179 |
+
if self.loss_type == "L1":
|
180 |
+
loss = l1_loss + self.bce_loss_lambda * bce_loss if self.bce_loss_lambda > 0.0 else l1_loss
|
181 |
+
elif self.loss_type == "L2":
|
182 |
+
loss = l2_loss + self.bce_loss_lambda * bce_loss if self.bce_loss_lambda > 0.0 else l2_loss
|
183 |
+
elif self.loss_type == "L1+L2":
|
184 |
+
loss = l1_loss + l2_loss + self.bce_loss_lambda * bce_loss if self.bce_loss_lambda > 0.0 else l1_loss + l2_loss
|
185 |
+
else:
|
186 |
+
raise ValueError("unknown --loss-type " + self.loss_type)
|
187 |
+
|
188 |
+
# calculate guided attention loss
|
189 |
+
enc_dec_attn_loss = None
|
190 |
+
if self.use_guided_attn_loss:
|
191 |
+
# calculate the encoder input lengths, which are determined by the encoder prenet
|
192 |
+
if hasattr(model, 'encoder_reduction_factor') and model.encoder_reduction_factor > 1:
|
193 |
+
ilens_in = ilens.new([ilen // model.encoder_reduction_factor for ilen in ilens])
|
194 |
+
else:
|
195 |
+
ilens_in = ilens
|
196 |
+
# handle the speech-to-speech model's input
|
197 |
+
if "task_name" in sample and sample["task_name"] == "s2s":
|
198 |
+
m = None
|
199 |
+
if hasattr(model, 'encoder_prenet'):
|
200 |
+
m = model.encoder_prenet
|
201 |
+
elif hasattr(model, 'speech_encoder_prenet'):
|
202 |
+
m = model.speech_encoder_prenet
|
203 |
+
if m is not None and isinstance(m, SpeechEncoderPrenet):
|
204 |
+
ilens_in = m.get_src_lengths(ilens_in)
|
205 |
+
# calculate for encoder-decoder
|
206 |
+
if "encoder-decoder" in self.modules_applied_guided_attn:
|
207 |
+
attn = [att_l[:, : self.num_heads_applied_guided_attn] for att_l in attn]
|
208 |
+
att_ws = torch.cat(attn, dim=1) # (B, H*L, T_out, T_in)
|
209 |
+
enc_dec_attn_loss = self.attn_criterion(att_ws, ilens_in, olens_in)
|
210 |
+
loss = loss + enc_dec_attn_loss
|
211 |
+
|
212 |
+
return loss, l1_loss, l2_loss, bce_loss, enc_dec_attn_loss
|
213 |
+
|
214 |
+
@classmethod
|
215 |
+
def reduce_metrics(cls, logging_outputs) -> None:
|
216 |
+
"""Aggregate logging outputs from data parallel training."""
|
217 |
+
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
|
218 |
+
l1_loss_sum = sum(log.get("l1_loss", 0) for log in logging_outputs)
|
219 |
+
l2_loss_sum = sum(log.get("l2_loss", 0) for log in logging_outputs)
|
220 |
+
bce_loss_sum = sum(log.get("bce_loss", 0) for log in logging_outputs)
|
221 |
+
sample_size = max(1, sum(log.get("sample_size", 0) for log in logging_outputs))
|
222 |
+
metrics.log_scalar(
|
223 |
+
"loss", loss_sum / sample_size, sample_size, 1, round=5
|
224 |
+
)
|
225 |
+
encoder_alpha_sum = sum(log.get("encoder_alpha", 0) for log in logging_outputs)
|
226 |
+
decoder_alpha_sum = sum(log.get("decoder_alpha", 0) for log in logging_outputs)
|
227 |
+
ngpu = sum(log.get("ngpu", 0) for log in logging_outputs)
|
228 |
+
|
229 |
+
metrics.log_scalar(
|
230 |
+
"l1_loss", l1_loss_sum / sample_size, sample_size, 2, round=5
|
231 |
+
)
|
232 |
+
metrics.log_scalar(
|
233 |
+
"l2_loss", l2_loss_sum / sample_size, sample_size, 2, round=5
|
234 |
+
)
|
235 |
+
metrics.log_scalar(
|
236 |
+
"bce_loss", bce_loss_sum / sample_size, sample_size, 2, round=5
|
237 |
+
)
|
238 |
+
metrics.log_scalar(
|
239 |
+
"encoder_alpha", encoder_alpha_sum / sample_size, sample_size, round=5
|
240 |
+
)
|
241 |
+
metrics.log_scalar(
|
242 |
+
"decoder_alpha", decoder_alpha_sum / sample_size, sample_size, round=5
|
243 |
+
)
|
244 |
+
|
245 |
+
if "enc_dec_attn_loss" in logging_outputs[0]:
|
246 |
+
enc_dec_attn_loss_sum = sum(log.get("enc_dec_attn_loss", 0) for log in logging_outputs)
|
247 |
+
metrics.log_scalar(
|
248 |
+
"enc_dec_attn_loss", enc_dec_attn_loss_sum / sample_size, sample_size, round=8
|
249 |
+
)
|
250 |
+
|
251 |
+
|
252 |
+
@staticmethod
|
253 |
+
def logging_outputs_can_be_summed() -> bool:
|
254 |
+
"""
|
255 |
+
Whether the logging outputs returned by `forward` can be summed
|
256 |
+
across workers prior to calling `reduce_metrics`. Setting this
|
257 |
+
to True will improve distributed training speed.
|
258 |
+
"""
|
259 |
+
return True
|
260 |
+
|
261 |
+
class Tacotron2Loss(torch.nn.Module):
|
262 |
+
"""Loss function module for Tacotron2."""
|
263 |
+
|
264 |
+
def __init__(
|
265 |
+
self, use_masking=True, use_weighted_masking=False, bce_pos_weight=20.0
|
266 |
+
):
|
267 |
+
"""Initialize Tacotron2 loss module.
|
268 |
+
|
269 |
+
Args:
|
270 |
+
use_masking (bool): Whether to apply masking
|
271 |
+
for padded part in loss calculation.
|
272 |
+
use_weighted_masking (bool):
|
273 |
+
Whether to apply weighted masking in loss calculation.
|
274 |
+
bce_pos_weight (float): Weight of positive sample of stop token.
|
275 |
+
|
276 |
+
"""
|
277 |
+
super(Tacotron2Loss, self).__init__()
|
278 |
+
assert (use_masking != use_weighted_masking) or not use_masking
|
279 |
+
self.use_masking = use_masking
|
280 |
+
self.use_weighted_masking = use_weighted_masking
|
281 |
+
|
282 |
+
# define criterions
|
283 |
+
# reduction = "none" if self.use_weighted_masking else "sum"
|
284 |
+
reduction = "none" if self.use_weighted_masking else "mean"
|
285 |
+
self.l1_criterion = torch.nn.L1Loss(reduction=reduction)
|
286 |
+
self.mse_criterion = torch.nn.MSELoss(reduction=reduction)
|
287 |
+
self.bce_criterion = torch.nn.BCEWithLogitsLoss(
|
288 |
+
reduction=reduction, pos_weight=torch.tensor(bce_pos_weight)
|
289 |
+
)
|
290 |
+
|
291 |
+
# NOTE(kan-bayashi): register pre hook function for the compatibility
|
292 |
+
self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
|
293 |
+
|
294 |
+
def forward(self, after_outs, before_outs, logits, ys, labels, olens):
|
295 |
+
"""Calculate forward propagation.
|
296 |
+
|
297 |
+
Args:
|
298 |
+
after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim).
|
299 |
+
before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim).
|
300 |
+
logits (Tensor): Batch of stop logits (B, Lmax).
|
301 |
+
ys (Tensor): Batch of padded target features (B, Lmax, odim).
|
302 |
+
labels (LongTensor): Batch of the sequences of stop token labels (B, Lmax).
|
303 |
+
olens (LongTensor): Batch of the lengths of each target (B,).
|
304 |
+
|
305 |
+
Returns:
|
306 |
+
Tensor: L1 loss value.
|
307 |
+
Tensor: Mean square error loss value.
|
308 |
+
Tensor: Binary cross entropy loss value.
|
309 |
+
|
310 |
+
"""
|
311 |
+
# make mask and apply it
|
312 |
+
if self.use_masking:
|
313 |
+
masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
|
314 |
+
ys = ys.masked_select(masks)
|
315 |
+
after_outs = after_outs.masked_select(masks)
|
316 |
+
before_outs = before_outs.masked_select(masks)
|
317 |
+
labels = labels.masked_select(masks[:, :, 0])
|
318 |
+
logits = logits.masked_select(masks[:, :, 0])
|
319 |
+
|
320 |
+
# calculate loss
|
321 |
+
l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion(before_outs, ys)
|
322 |
+
mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion(
|
323 |
+
before_outs, ys
|
324 |
+
)
|
325 |
+
bce_loss = self.bce_criterion(logits, labels)
|
326 |
+
|
327 |
+
# make weighted mask and apply it
|
328 |
+
if self.use_weighted_masking:
|
329 |
+
masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
|
330 |
+
weights = masks.float() / masks.sum(dim=1, keepdim=True).float()
|
331 |
+
out_weights = weights.div(ys.size(0) * ys.size(2))
|
332 |
+
logit_weights = weights.div(ys.size(0))
|
333 |
+
|
334 |
+
# apply weight
|
335 |
+
l1_loss = l1_loss.mul(out_weights).masked_select(masks).sum()
|
336 |
+
mse_loss = mse_loss.mul(out_weights).masked_select(masks).sum()
|
337 |
+
bce_loss = (
|
338 |
+
bce_loss.mul(logit_weights.squeeze(-1))
|
339 |
+
.masked_select(masks.squeeze(-1))
|
340 |
+
.sum()
|
341 |
+
)
|
342 |
+
|
343 |
+
return l1_loss, mse_loss, bce_loss
|
344 |
+
|
345 |
+
def _load_state_dict_pre_hook(
|
346 |
+
self,
|
347 |
+
state_dict,
|
348 |
+
prefix,
|
349 |
+
local_metadata,
|
350 |
+
strict,
|
351 |
+
missing_keys,
|
352 |
+
unexpected_keys,
|
353 |
+
error_msgs,
|
354 |
+
):
|
355 |
+
"""Apply pre-hook function before loading the state dict.
|
356 |
+
|
357 |
+
From v.0.6.1, `bce_criterion.pos_weight` is registered as a parameter, but
|
358 |
+
old models do not include it; as a result, a missing-key error is raised when
|
359 |
+
loading old model parameters. This function solves the issue by adding the param to the
|
360 |
+
state dict before loading, as a pre-hook function
|
361 |
+
of the `load_state_dict` method.
|
362 |
+
|
363 |
+
"""
|
364 |
+
key = prefix + "bce_criterion.pos_weight"
|
365 |
+
if key not in state_dict:
|
366 |
+
state_dict[key] = self.bce_criterion.pos_weight
|
367 |
+
|
368 |
+
class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss):
|
369 |
+
"""Guided attention loss function module for multi head attention.
|
370 |
+
Args:
|
371 |
+
sigma (float, optional): Standard deviation to control
|
372 |
+
how close the attention is to the diagonal.
|
373 |
+
alpha (float, optional): Scaling coefficient (lambda).
|
374 |
+
reset_always (bool, optional): Whether to always reset masks.
|
375 |
+
"""
|
376 |
+
|
377 |
+
def forward(self, att_ws, ilens, olens):
|
378 |
+
"""Calculate forward propagation.
|
379 |
+
Args:
|
380 |
+
att_ws (Tensor):
|
381 |
+
Batch of multi head attention weights (B, H, T_max_out, T_max_in).
|
382 |
+
ilens (LongTensor): Batch of input lengths (B,).
|
383 |
+
olens (LongTensor): Batch of output lengths (B,).
|
384 |
+
Returns:
|
385 |
+
Tensor: Guided attention loss value.
|
386 |
+
"""
|
387 |
+
if self.guided_attn_masks is None:
|
388 |
+
self.guided_attn_masks = (
|
389 |
+
self._make_guided_attention_masks(ilens, olens)
|
390 |
+
.to(att_ws.device)
|
391 |
+
.unsqueeze(1)
|
392 |
+
)
|
393 |
+
if self.masks is None:
|
394 |
+
self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1)
|
395 |
+
losses = self.guided_attn_masks * att_ws
|
396 |
+
loss = torch.mean(losses.masked_select(self.masks))
|
397 |
+
if self.reset_always:
|
398 |
+
self._reset_masks()
|
399 |
+
|
400 |
+
return self.alpha * loss
|
401 |
+
|
402 |
+
def _make_guided_attention_masks(self, ilens, olens):
|
403 |
+
n_batches = len(ilens)
|
404 |
+
max_ilen = max(ilens)
|
405 |
+
max_olen = max(olens)
|
406 |
+
guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen), device=olens.device)
|
407 |
+
for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
|
408 |
+
guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(
|
409 |
+
ilen, olen, self.sigma
|
410 |
+
)
|
411 |
+
return guided_attn_masks
|
412 |
+
|
413 |
+
@staticmethod
|
414 |
+
def _make_guided_attention_mask(ilen, olen, sigma):
|
415 |
+
grid_x, grid_y = torch.meshgrid(torch.arange(olen, device=olen.device), torch.arange(ilen, device=olen.device))
|
416 |
+
grid_x, grid_y = grid_x.float(), grid_y.float()
|
417 |
+
return 1.0 - torch.exp(
|
418 |
+
-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma**2))
|
419 |
+
)
|
420 |
+
|
421 |
+
@staticmethod
|
422 |
+
def _make_masks(ilens, olens):
|
423 |
+
in_masks = make_non_pad_mask(ilens).to(ilens.device) # (B, T_in)
|
424 |
+
out_masks = make_non_pad_mask(olens).to(olens.device) # (B, T_out)
|
425 |
+
return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) # (B, T_out, T_in)
|
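The guided attention penalty built by _make_guided_attention_mask above assigns each attention cell the weight 1 - exp(-((t_in/T_in - t_out/T_out)^2) / (2*sigma^2)), so attention mass far from the time-aligned diagonal is penalised. A self-contained sketch of that mask (plain tensors, no fairseq/espnet dependencies; indexing="ij" mirrors the behaviour of the code above):

import torch

def guided_attention_mask(ilen: int, olen: int, sigma: float = 0.4) -> torch.Tensor:
    # grid_x indexes decoder (output) frames, grid_y encoder (input) positions
    grid_x, grid_y = torch.meshgrid(
        torch.arange(olen), torch.arange(ilen), indexing="ij"
    )
    grid_x, grid_y = grid_x.float(), grid_y.float()
    return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * sigma ** 2))

mask = guided_attention_mask(ilen=5, olen=8)
print(mask.shape)   # torch.Size([8, 5]); entries near the diagonal are close to 0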
artst/data/__init__.py
ADDED
File without changes
|
artst/data/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (140 Bytes).
|
|
artst/data/__pycache__/multitask_dataset.cpython-38.pyc
ADDED
Binary file (8.91 kB).
|
|
artst/data/__pycache__/speech_dataset.cpython-38.pyc
ADDED
Binary file (16.7 kB).
|
|
artst/data/__pycache__/speech_to_class_dataset.cpython-38.pyc
ADDED
Binary file (8.35 kB).
|
|
artst/data/__pycache__/speech_to_speech_dataset.cpython-38.pyc
ADDED
Binary file (9.76 kB).
|
|
artst/data/__pycache__/speech_to_text_dataset.cpython-38.pyc
ADDED
Binary file (9.7 kB).
|
|
artst/data/__pycache__/text_dataset.cpython-38.pyc
ADDED
Binary file (11.6 kB).
|
|
artst/data/__pycache__/text_to_speech_dataset.cpython-38.pyc
ADDED
Binary file (12.1 kB).
|
|
artst/data/multitask_dataset.py
ADDED
@@ -0,0 +1,263 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
# Based on speecht5, fairseq and espnet code bases
|
5 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import bisect
|
9 |
+
|
10 |
+
import logging
|
11 |
+
import numpy as np
|
12 |
+
from torch.utils.data.dataloader import default_collate
|
13 |
+
from fairseq.data import data_utils
|
14 |
+
|
15 |
+
from fairseq.data.fairseq_dataset import FairseqDataset
|
16 |
+
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
class MultitaskDataset(FairseqDataset):
|
20 |
+
@staticmethod
|
21 |
+
def cumsum(sequence):
|
22 |
+
r, s = [], 0
|
23 |
+
for e in sequence:
|
24 |
+
curr_len = len(e)
|
25 |
+
r.append(curr_len + s)
|
26 |
+
s += curr_len
|
27 |
+
return r
|
28 |
+
|
29 |
+
def __init__(self, datasets, sample_ratios=1, batch_ratio=None):
|
30 |
+
super(MultitaskDataset, self).__init__()
|
31 |
+
assert len(datasets) > 0, "datasets should not be an empty iterable"
|
32 |
+
self.datasets = list(datasets)
|
33 |
+
if isinstance(sample_ratios, int):
|
34 |
+
sample_ratios = [sample_ratios] * len(self.datasets)
|
35 |
+
if batch_ratio is not None:
|
36 |
+
logger.info('batch ratio is ' + str(batch_ratio))
|
37 |
+
self.batch_ratio = batch_ratio
|
38 |
+
else:
|
39 |
+
self.batch_ratio = None
|
40 |
+
else:
|
41 |
+
logger.info('set sample ratio to ' + str(sample_ratios))
|
42 |
+
if batch_ratio is not None:
|
43 |
+
logger.info('batch ratio is ' + str(batch_ratio))
|
44 |
+
self.batch_ratio = batch_ratio
|
45 |
+
else:
|
46 |
+
self.batch_ratio = None
|
47 |
+
self.sample_ratios = sample_ratios
|
48 |
+
self._ordered_indices = None
|
49 |
+
self._update_size()
|
50 |
+
|
51 |
+
def __len__(self):
|
52 |
+
return self.cumulative_sizes[-1]
|
53 |
+
|
54 |
+
def __getitem__(self, idx):
|
55 |
+
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
|
56 |
+
sample = self.datasets[dataset_idx][sample_idx]
|
57 |
+
if isinstance(sample, dict):
|
58 |
+
sample["dataset_idx"] = dataset_idx
|
59 |
+
else:
|
60 |
+
sample = sample + (dataset_idx,)
|
61 |
+
return sample
|
62 |
+
|
63 |
+
def _update_size(self):
|
64 |
+
self.cumulative_sizes = self.cumsum(self.datasets)
|
65 |
+
self.real_sizes = [len(d) for d in self.datasets]
|
66 |
+
|
67 |
+
def _get_dataset_and_sample_index(self, idx: int):
|
68 |
+
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
69 |
+
if dataset_idx == 0:
|
70 |
+
sample_idx = idx
|
71 |
+
else:
|
72 |
+
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
73 |
+
sample_idx = sample_idx % self.real_sizes[dataset_idx]
|
74 |
+
return dataset_idx, sample_idx
|
75 |
+
|
76 |
+
def collater(self, samples, **extra_args):
|
77 |
+
# For now, only datasets with the same underlying collater implementation are supported
|
78 |
+
if samples is not None and len(samples) > 0:
|
79 |
+
if isinstance(samples[0], dict):
|
80 |
+
dataset_idx = samples[0]["dataset_idx"]
|
81 |
+
else:
|
82 |
+
dataset_idx = samples[0][-1]
|
83 |
+
samples = [sample[:-1] for sample in samples]
|
84 |
+
else:
|
85 |
+
dataset_idx = 0
|
86 |
+
|
87 |
+
if hasattr(self.datasets[dataset_idx], "collater"):
|
88 |
+
return self.datasets[dataset_idx].collater(samples, **extra_args)
|
89 |
+
else:
|
90 |
+
return default_collate(samples, **extra_args)
|
91 |
+
|
92 |
+
def size(self, idx: int):
|
93 |
+
"""
|
94 |
+
Return an example's size as a float or tuple.
|
95 |
+
"""
|
96 |
+
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
|
97 |
+
return self.datasets[dataset_idx].size(sample_idx)
|
98 |
+
|
99 |
+
def num_tokens(self, index: int):
|
100 |
+
return np.max(self.size(index))
|
101 |
+
|
102 |
+
def attr(self, attr: str, index: int):
|
103 |
+
dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
|
104 |
+
return getattr(self.datasets[dataset_idx], attr, None)
|
105 |
+
|
106 |
+
@property
|
107 |
+
def sizes(self):
|
108 |
+
_dataset_sizes = []
|
109 |
+
for ds in self.datasets:
|
110 |
+
if isinstance(ds.sizes, np.ndarray):
|
111 |
+
_dataset_sizes.append(ds.sizes)
|
112 |
+
else:
|
113 |
+
# Only underlying datasets with a single size array are supported.
|
114 |
+
assert isinstance(ds.sizes, list)
|
115 |
+
_dataset_sizes.append(ds.sizes[0])
|
116 |
+
return np.concatenate(_dataset_sizes)
|
117 |
+
|
118 |
+
@property
|
119 |
+
def supports_prefetch(self):
|
120 |
+
return all(d.supports_prefetch for d in self.datasets)
|
121 |
+
|
122 |
+
def ordered_indices(self):
|
123 |
+
# ordered_indices = []
|
124 |
+
# for i, dataset in enumerate(self.datasets):
|
125 |
+
# indice = dataset.ordered_indices()
|
126 |
+
# ordered_indices.append(indice)
|
127 |
+
if self._ordered_indices is None:
|
128 |
+
# Call the underlying dataset's ordered_indices() here, so that we
|
129 |
+
# get the same random ordering as we would have from using the
|
130 |
+
# underlying sub-datasets directly.
|
131 |
+
self._ordered_indices = [
|
132 |
+
dataset.ordered_indices()
|
133 |
+
for dataset in self.datasets
|
134 |
+
]
|
135 |
+
return np.arange(len(self))
|
136 |
+
|
137 |
+
def prefetch(self, indices):
|
138 |
+
frm = 0
|
139 |
+
for to, ds in zip(self.cumulative_sizes, self.datasets):
|
140 |
+
real_size = len(ds)
|
141 |
+
if getattr(ds, "supports_prefetch", False):
|
142 |
+
ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
|
143 |
+
frm = to
|
144 |
+
|
145 |
+
def batch_by_size(
|
146 |
+
self,
|
147 |
+
indices,
|
148 |
+
max_tokens=None,
|
149 |
+
max_sentences=None,
|
150 |
+
required_batch_size_multiple=1,
|
151 |
+
):
|
152 |
+
if not hasattr(self, "max_tokens"):
|
153 |
+
self.max_tokens = max_tokens
|
154 |
+
if not hasattr(self, "max_sentences"):
|
155 |
+
self.max_sentences = max_sentences
|
156 |
+
if not hasattr(self, "required_batch_size_multiple"):
|
157 |
+
self.required_batch_size_multiple = required_batch_size_multiple
|
158 |
+
batch_samplers = []
|
159 |
+
for i, dataset in enumerate(self.datasets):
|
160 |
+
batch_sampler = dataset.batch_by_size(
|
161 |
+
self._ordered_indices[i],
|
162 |
+
max_tokens=max_tokens if self.batch_ratio is None else max_tokens * self.batch_ratio[i],
|
163 |
+
max_sentences=max_sentences,
|
164 |
+
required_batch_size_multiple=required_batch_size_multiple,
|
165 |
+
)
|
166 |
+
if i > 0:
|
167 |
+
for batch in batch_sampler:
|
168 |
+
batch += self.cumulative_sizes[i - 1]
|
169 |
+
if self.sample_ratios[i] != 1.0:
|
170 |
+
batch_sampler = np.array(batch_sampler)
|
171 |
+
batch_sampler = np.random.choice(batch_sampler, int(len(batch_sampler) * self.sample_ratios[i]))
|
172 |
+
batch_sampler = list(batch_sampler)
|
173 |
+
logger.info('Adjust batch by ratio ' + str(self.sample_ratios[i]) + ' and the number of batch is ' + str(int(len(batch_sampler))) + ' for dataset ' + str(i))
|
174 |
+
batch_samplers.extend(batch_sampler)
|
175 |
+
return batch_samplers
|
176 |
+
|
177 |
+
def filter_indices_by_size(self, indices, max_positions):
|
178 |
+
"""
|
179 |
+
Filter each sub-dataset independently, then update the round robin to work
|
180 |
+
on the filtered sub-datasets.
|
181 |
+
"""
|
182 |
+
if not hasattr(self, "max_positions"):
|
183 |
+
self.max_positions = max_positions
|
184 |
+
ignored_some = False
|
185 |
+
for i in range(len(self.datasets)):
|
186 |
+
# ignored = []
|
187 |
+
self._ordered_indices[i], ignored = self.datasets[i].filter_indices_by_size(
|
188 |
+
self._ordered_indices[i], self.max_positions[i]
|
189 |
+
)
|
190 |
+
if len(ignored) > 0:
|
191 |
+
ignored_some = True
|
192 |
+
logger.warning(
|
193 |
+
f"{len(ignored)} samples from {i} have invalid sizes and will be skipped, "
|
194 |
+
f"max_positions={self.max_positions[i]}, first few sample ids={ignored[:10]}"
|
195 |
+
)
|
196 |
+
|
197 |
+
logger.info('update dataset size')
|
198 |
+
self._update_size()
|
199 |
+
|
200 |
+
# Since we are modifying _ordered_indices in place,
|
201 |
+
# it is no longer possible to return valid ignored indices.
|
202 |
+
# Hopefully the extra debug information printed above is enough to debug.
|
203 |
+
# Ideally we would receive ignore_invalid_inputs so that we could have
|
204 |
+
# a proper error message.
|
205 |
+
return (np.arange(len(self)), [0] if ignored_some else [])
|
206 |
+
|
207 |
+
@property
|
208 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
209 |
+
return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
|
210 |
+
|
211 |
+
def set_epoch(self, epoch):
|
212 |
+
super().set_epoch(epoch)
|
213 |
+
for ds in self.datasets:
|
214 |
+
if hasattr(ds, "set_epoch"):
|
215 |
+
ds.set_epoch(epoch)
|
216 |
+
|
217 |
+
def shuffle_batches(self, batches, seed):
|
218 |
+
logger.info("shuffle batches")
|
219 |
+
new_batches_fromlist = []
|
220 |
+
new_batches_notlist = []
|
221 |
+
new_batches = []
|
222 |
+
with data_utils.numpy_seed(seed):
|
223 |
+
np.random.shuffle(batches)
|
224 |
+
for batch in batches:
|
225 |
+
if isinstance(batch, list):
|
226 |
+
# np.random.shuffle(batch)
|
227 |
+
new_batches_fromlist.append(batch)
|
228 |
+
else:
|
229 |
+
new_batches_notlist.append(batch)
|
230 |
+
logger.info("Get " + str(len(new_batches_fromlist)) + " chunk from speech sides")
|
231 |
+
logger.info("Get " + str(sum([len(batch_list) for batch_list in new_batches_fromlist])) + " batches from speech sides")
|
232 |
+
logger.info("Get " + str(len(new_batches_notlist)) + " batches from text sides")
|
233 |
+
if len(new_batches_fromlist) == 0:
|
234 |
+
return new_batches_notlist
|
235 |
+
st_ratio = int(len(new_batches_notlist) / len(new_batches_fromlist))
|
236 |
+
logger.info("Get st_ratio " + str(st_ratio))
|
237 |
+
last_idx = 0
|
238 |
+
for i in range(len(new_batches_fromlist)):
|
239 |
+
if i == len(new_batches_fromlist) - 1:
|
240 |
+
new_batches_fromlist[i].extend(new_batches_notlist[last_idx:])
|
241 |
+
else:
|
242 |
+
new_batches_fromlist[i].extend(new_batches_notlist[last_idx : last_idx + st_ratio])
|
243 |
+
np.random.shuffle(new_batches_fromlist[i])
|
244 |
+
new_batches.extend(new_batches_fromlist[i])
|
245 |
+
last_idx = last_idx + st_ratio
|
246 |
+
logger.info("Finish shuffle")
|
247 |
+
return new_batches
|
248 |
+
|
249 |
+
def reset_batch_sampler(self):
|
250 |
+
logger.info("reset batch sampler")
|
251 |
+
self._ordered_indices = [
|
252 |
+
self.datasets[i].ordered_indices()
|
253 |
+
for i in range(len(self.datasets))
|
254 |
+
]
|
255 |
+
self.filter_indices_by_size(None, None)
|
256 |
+
|
257 |
+
batch_samplers = self.batch_by_size(
|
258 |
+
None,
|
259 |
+
self.max_tokens,
|
260 |
+
self.max_sentences,
|
261 |
+
self.required_batch_size_multiple
|
262 |
+
)
|
263 |
+
return batch_samplers
|
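The flat-to-local index arithmetic that MultitaskDataset performs with cumsum and bisect_right can be seen in isolation in this toy sketch (the sub-dataset lengths are made up for illustration):

import bisect
import itertools

dataset_lengths = [3, 5, 2]                                  # hypothetical sub-dataset sizes
cumulative = list(itertools.accumulate(dataset_lengths))     # [3, 8, 10]

def locate(idx: int):
    # same arithmetic as _get_dataset_and_sample_index
    dataset_idx = bisect.bisect_right(cumulative, idx)
    start = 0 if dataset_idx == 0 else cumulative[dataset_idx - 1]
    return dataset_idx, idx - start

print([locate(i) for i in range(sum(dataset_lengths))])
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (2, 0), (2, 1)]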
artst/data/speech_dataset.py
ADDED
@@ -0,0 +1,475 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
|
5 |
+
# Based on speecht5, fairseq and espnet code bases
|
6 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
7 |
+
# --------------------------------------------------------
|
8 |
+
|
9 |
+
import itertools
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import sys
|
13 |
+
from typing import Any, List, Optional, Union
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
import librosa
|
20 |
+
from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform
|
21 |
+
from fairseq.data import data_utils
|
22 |
+
from fairseq.data.fairseq_dataset import FairseqDataset
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
def _collate_frames(
|
27 |
+
frames: List[torch.Tensor], is_audio_input: bool = False
|
28 |
+
):
|
29 |
+
"""
|
30 |
+
Convert a list of 2D frames into a padded 3D tensor
|
31 |
+
Args:
|
32 |
+
frames (list): list of 2D frames of size L[i]*f_dim, where L[i] is the
|
33 |
+
length of the i-th frame and f_dim is the static feature dimension
|
34 |
+
Returns:
|
35 |
+
3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
|
36 |
+
"""
|
37 |
+
max_len = max(frame.size(0) for frame in frames)
|
38 |
+
if is_audio_input:
|
39 |
+
out = frames[0].new_zeros((len(frames), max_len))
|
40 |
+
else:
|
41 |
+
out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
|
42 |
+
for i, v in enumerate(frames):
|
43 |
+
out[i, : v.size(0)] = v
|
44 |
+
return out
|
45 |
+
|
46 |
+
def add_first_frame_and_remove_last_frame(ys):
|
47 |
+
ys_in = torch.cat(
|
48 |
+
[ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1
|
49 |
+
)
|
50 |
+
return ys_in
|
51 |
+
|
52 |
+
def load_audio(manifest_path, max_keep, min_keep):
|
53 |
+
n_long, n_short = 0, 0
|
54 |
+
names, inds, sizes, spk_embeds = [], [], [], []
|
55 |
+
with open(manifest_path) as f:
|
56 |
+
root = f.readline().strip()
|
57 |
+
for ind, line in enumerate(f):
|
58 |
+
items = line.strip().split("\t")
|
59 |
+
assert len(items) == 3, line
|
60 |
+
sz = int(items[1])
|
61 |
+
if min_keep is not None and sz < min_keep:
|
62 |
+
n_short += 1
|
63 |
+
elif max_keep is not None and sz > max_keep:
|
64 |
+
n_long += 1
|
65 |
+
else:
|
66 |
+
names.append(items[0])
|
67 |
+
spk_embeds.append(items[2])
|
68 |
+
inds.append(ind)
|
69 |
+
sizes.append(sz)
|
70 |
+
tot = ind + 1
|
71 |
+
logger.info(
|
72 |
+
(
|
73 |
+
f"max_keep={max_keep}, min_keep={min_keep}, "
|
74 |
+
f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
|
75 |
+
f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
|
76 |
+
)
|
77 |
+
)
|
78 |
+
return root, names, inds, tot, sizes, spk_embeds
|
79 |
+
|
80 |
+
|
81 |
+
def load_label(label_path, inds, tot):
|
82 |
+
with open(label_path) as f:
|
83 |
+
labels = [line.rstrip() for line in f]
|
84 |
+
assert (
|
85 |
+
len(labels) == tot
|
86 |
+
), f"number of labels does not match ({len(labels)} != {tot})"
|
87 |
+
labels = [labels[i] for i in inds]
|
88 |
+
return labels
|
89 |
+
|
90 |
+
|
91 |
+
def load_label_offset(label_path, inds, tot):
|
92 |
+
with open(label_path) as f:
|
93 |
+
code_lengths = [len(line.encode("utf-8")) for line in f]
|
94 |
+
assert (
|
95 |
+
len(code_lengths) == tot
|
96 |
+
), f"number of labels does not match ({len(code_lengths)} != {tot})"
|
97 |
+
offsets = list(itertools.accumulate([0] + code_lengths))
|
98 |
+
offsets = [(offsets[i], offsets[i + 1]) for i in inds]
|
99 |
+
return offsets
|
100 |
+
|
101 |
+
|
102 |
+
def verify_label_lengths(
|
103 |
+
audio_sizes,
|
104 |
+
audio_rate,
|
105 |
+
label_path,
|
106 |
+
label_rate,
|
107 |
+
inds,
|
108 |
+
tot,
|
109 |
+
tol=0.1, # tolerance in seconds
|
110 |
+
):
|
111 |
+
if label_rate < 0:
|
112 |
+
logger.info(f"{label_path} is sequence label. skipped")
|
113 |
+
return
|
114 |
+
|
115 |
+
with open(label_path) as f:
|
116 |
+
lengths = [len(line.rstrip().split()) for line in f]
|
117 |
+
assert len(lengths) == tot
|
118 |
+
lengths = [lengths[i] for i in inds]
|
119 |
+
num_invalid = 0
|
120 |
+
for i, ind in enumerate(inds):
|
121 |
+
dur_from_audio = audio_sizes[i] / audio_rate
|
122 |
+
dur_from_label = lengths[i] / label_rate
|
123 |
+
if abs(dur_from_audio - dur_from_label) > tol:
|
124 |
+
logger.warning(
|
125 |
+
(
|
126 |
+
f"audio and label duration differ too much "
|
127 |
+
f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
|
128 |
+
f"in line {ind+1} of {label_path}. Check if `label_rate` "
|
129 |
+
f"is correctly set (currently {label_rate}). "
|
130 |
+
f"num. of samples = {audio_sizes[i]}; "
|
131 |
+
f"label length = {lengths[i]}"
|
132 |
+
)
|
133 |
+
)
|
134 |
+
num_invalid += 1
|
135 |
+
if num_invalid > 0:
|
136 |
+
logger.warning(
|
137 |
+
f"total {num_invalid} (audio, label) pairs with mismatched lengths"
|
138 |
+
)
|
139 |
+
|
140 |
+
|
141 |
+
def logmelfilterbank(
|
142 |
+
audio,
|
143 |
+
sampling_rate,
|
144 |
+
fft_size=1024,
|
145 |
+
hop_size=256,
|
146 |
+
win_length=None,
|
147 |
+
window="hann",
|
148 |
+
num_mels=80,
|
149 |
+
fmin=80,
|
150 |
+
fmax=7600,
|
151 |
+
eps=1e-10,
|
152 |
+
):
|
153 |
+
"""Compute log-Mel filterbank feature.
|
154 |
+
(https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/bin/preprocess.py)
|
155 |
+
|
156 |
+
Args:
|
157 |
+
audio (ndarray): Audio signal (T,).
|
158 |
+
sampling_rate (int): Sampling rate.
|
159 |
+
fft_size (int): FFT size.
|
160 |
+
hop_size (int): Hop size.
|
161 |
+
win_length (int): Window length. If set to None, it will be the same as fft_size.
|
162 |
+
window (str): Window function type.
|
163 |
+
num_mels (int): Number of mel basis.
|
164 |
+
fmin (int): Minimum frequency in mel basis calculation.
|
165 |
+
fmax (int): Maximum frequency in mel basis calculation.
|
166 |
+
eps (float): Epsilon value to avoid inf in log calculation.
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
ndarray: Log Mel filterbank feature (#frames, num_mels).
|
170 |
+
|
171 |
+
"""
|
172 |
+
# get amplitude spectrogram
|
173 |
+
x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size,
|
174 |
+
win_length=win_length, window=window, pad_mode="reflect")
|
175 |
+
spc = np.abs(x_stft).T # (#frames, #bins)
|
176 |
+
|
177 |
+
# get mel basis
|
178 |
+
fmin = 0 if fmin is None else fmin
|
179 |
+
fmax = sampling_rate / 2 if fmax is None else fmax
|
180 |
+
mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
181 |
+
|
182 |
+
return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
|
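A minimal usage sketch for the logmelfilterbank helper defined above (the synthetic 440 Hz tone and the 16 kHz rate are assumptions for illustration):

import numpy as np

sr = 16000
t = np.linspace(0, 1.0, sr, endpoint=False)
audio = (0.1 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)   # one second of a 440 Hz tone

feats = logmelfilterbank(audio, sampling_rate=sr)   # defaults: 1024-point FFT, hop 256, 80 mels
print(feats.shape)                                  # (#frames, 80)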
183 |
+
|
184 |
+
|
185 |
+
class SpeechPretrainDataset(FairseqDataset):
|
186 |
+
def __init__(
|
187 |
+
self,
|
188 |
+
manifest_path: str,
|
189 |
+
sample_rate: float,
|
190 |
+
label_paths: List[str],
|
191 |
+
label_rates: Union[List[float], float], # -1 for sequence labels
|
192 |
+
pad_list: List[str],
|
193 |
+
eos_list: List[str],
|
194 |
+
label_processors: Optional[List[Any]] = None,
|
195 |
+
max_keep_sample_size: Optional[int] = None,
|
196 |
+
min_keep_sample_size: Optional[int] = None,
|
197 |
+
max_sample_size: Optional[int] = None,
|
198 |
+
shuffle: bool = True,
|
199 |
+
pad_audio: bool = False,
|
200 |
+
normalize: bool = False,
|
201 |
+
store_labels: bool = True,
|
202 |
+
random_crop: bool = False,
|
203 |
+
single_target: bool = False,
|
204 |
+
reduction_factor: int = 1,
|
205 |
+
):
|
206 |
+
self.audio_root, self.audio_names, inds, tot, self.sizes, self.spk_embeds = load_audio(
|
207 |
+
manifest_path, max_keep_sample_size, min_keep_sample_size
|
208 |
+
)
|
209 |
+
self.sample_rate = sample_rate
|
210 |
+
self.shuffle = shuffle
|
211 |
+
self.random_crop = random_crop
|
212 |
+
|
213 |
+
self.num_labels = len(label_paths)
|
214 |
+
self.pad_list = pad_list
|
215 |
+
self.eos_list = eos_list
|
216 |
+
self.label_processors = label_processors
|
217 |
+
self.single_target = single_target
|
218 |
+
self.label_rates = (
|
219 |
+
[label_rates for _ in range(len(label_paths))]
|
220 |
+
if isinstance(label_rates, float)
|
221 |
+
else label_rates
|
222 |
+
)
|
223 |
+
self.store_labels = store_labels
|
224 |
+
if store_labels:
|
225 |
+
self.label_list = [load_label(p, inds, tot) for p in label_paths]
|
226 |
+
else:
|
227 |
+
self.label_paths = label_paths
|
228 |
+
self.label_offsets_list = [
|
229 |
+
load_label_offset(p, inds, tot) for p in label_paths
|
230 |
+
]
|
231 |
+
assert label_processors is None or len(label_processors) == self.num_labels
|
232 |
+
for label_path, label_rate in zip(label_paths, self.label_rates):
|
233 |
+
verify_label_lengths(
|
234 |
+
self.sizes, sample_rate, label_path, label_rate, inds, tot
|
235 |
+
)
|
236 |
+
|
237 |
+
self.max_sample_size = (
|
238 |
+
max_sample_size if max_sample_size is not None else sys.maxsize
|
239 |
+
)
|
240 |
+
self.pad_audio = pad_audio
|
241 |
+
self.normalize = normalize
|
242 |
+
self.reduction_factor = reduction_factor
|
243 |
+
logger.info(
|
244 |
+
f"pad_audio={pad_audio}, random_crop={random_crop}, reduction_factor={reduction_factor}, "
|
245 |
+
f"normalize={normalize}, max_sample_size={self.max_sample_size}"
|
246 |
+
)
|
247 |
+
|
248 |
+
def get_audio(self, index):
|
249 |
+
import soundfile as sf
|
250 |
+
|
251 |
+
wav_path = os.path.join(self.audio_root, self.audio_names[index])
|
252 |
+
wav, cur_sample_rate = sf.read(wav_path)
|
253 |
+
wav = torch.from_numpy(wav).float()
|
254 |
+
fbank = logmelfilterbank(
|
255 |
+
wav.view(-1).cpu().numpy(), 16000
|
256 |
+
)
|
257 |
+
fbank = torch.from_numpy(fbank).float()
|
258 |
+
wav = self.postprocess(wav, cur_sample_rate)
|
259 |
+
return wav, fbank
|
260 |
+
|
261 |
+
def get_label(self, index, label_idx):
|
262 |
+
if self.store_labels:
|
263 |
+
label = self.label_list[label_idx][index]
|
264 |
+
else:
|
265 |
+
with open(self.label_paths[label_idx]) as f:
|
266 |
+
offset_s, offset_e = self.label_offsets_list[label_idx][index]
|
267 |
+
f.seek(offset_s)
|
268 |
+
label = f.read(offset_e - offset_s)
|
269 |
+
|
270 |
+
if self.label_processors is not None:
|
271 |
+
label = self.label_processors[label_idx](label)
|
272 |
+
return label
|
273 |
+
|
274 |
+
def get_labels(self, index):
|
275 |
+
return [self.get_label(index, i) for i in range(self.num_labels)]
|
276 |
+
|
277 |
+
def __getitem__(self, index):
|
278 |
+
wav, fbank = self.get_audio(index)
|
279 |
+
labels = self.get_labels(index)
|
280 |
+
spkembs = get_features_or_waveform(
|
281 |
+
os.path.join(self.audio_root, self.spk_embeds[index])
|
282 |
+
)
|
283 |
+
spkembs = torch.from_numpy(spkembs).float()
|
284 |
+
return {"id": index, "source": wav, "target": fbank, "label_list": labels, 'spkembs': spkembs}
|
285 |
+
|
286 |
+
def __len__(self):
|
287 |
+
return len(self.sizes)
|
288 |
+
|
289 |
+
def crop_to_max_size(self, wav, target_size):
|
290 |
+
size = len(wav)
|
291 |
+
diff = size - target_size
|
292 |
+
if diff <= 0:
|
293 |
+
return wav, 0
|
294 |
+
|
295 |
+
start, end = 0, target_size
|
296 |
+
if self.random_crop:
|
297 |
+
start = np.random.randint(0, diff + 1)
|
298 |
+
end = size - diff + start
|
299 |
+
return wav[start:end], start
|
300 |
+
|
301 |
+
def collater(self, samples):
|
302 |
+
# target = max(sizes) -> random_crop not used
|
303 |
+
# target = max_sample_size -> random_crop used for long
|
304 |
+
samples = [s for s in samples if s["source"] is not None]
|
305 |
+
if len(samples) == 0:
|
306 |
+
return {}
|
307 |
+
|
308 |
+
audios = [s["source"] for s in samples]
|
309 |
+
audio_sizes = [len(s) for s in audios]
|
310 |
+
|
311 |
+
fbanks = [s["target"] for s in samples]
|
312 |
+
fbank_sizes = [len(s) for s in fbanks]
|
313 |
+
|
314 |
+
if self.pad_audio:
|
315 |
+
audio_size = min(max(audio_sizes), self.max_sample_size)
|
316 |
+
else:
|
317 |
+
audio_size = min(min(audio_sizes), self.max_sample_size)
|
318 |
+
collated_audios, padding_mask, audio_starts = self.collater_audio(
|
319 |
+
audios, audio_size
|
320 |
+
)
|
321 |
+
|
322 |
+
collated_fbanks = []
|
323 |
+
collated_audios_size = []
|
324 |
+
for i in range(len(fbanks)):
|
325 |
+
fbank_start = int(audio_starts[i] / (audio_sizes[i] / fbank_sizes[i]))
|
326 |
+
fbank_size = int(audio_size / (audio_sizes[i] / fbank_sizes[i]))
|
327 |
+
fbank_end = min(fbank_start + fbank_size, fbank_sizes[i])
|
328 |
+
collated_fbanks.append(fbanks[i][fbank_start : fbank_end])
|
329 |
+
collated_audios_size.append(audio_size)
|
330 |
+
collated_fbanks_size = [len(s) for s in collated_fbanks]
|
331 |
+
collated_fbanks = _collate_frames(collated_fbanks)
|
332 |
+
collated_fbanks_size = torch.tensor(collated_fbanks_size, dtype=torch.long)
|
333 |
+
|
334 |
+
# thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
|
335 |
+
if self.reduction_factor > 1:
|
336 |
+
collated_fbanks_in = collated_fbanks[:, self.reduction_factor - 1 :: self.reduction_factor]
|
337 |
+
collated_fbanks_size_in = collated_fbanks_size.new([torch.div(olen, self.reduction_factor, rounding_mode='floor') for olen in collated_fbanks_size])
|
338 |
+
else:
|
339 |
+
collated_fbanks_in, collated_fbanks_size_in = collated_fbanks, collated_fbanks_size
|
340 |
+
|
341 |
+
prev_output_tokens = torch.cat(
|
342 |
+
[collated_fbanks_in.new_zeros((collated_fbanks_in.shape[0], 1, collated_fbanks_in.shape[2])), collated_fbanks_in[:, :-1]], dim=1
|
343 |
+
)
|
344 |
+
|
345 |
+
# make labels for stop prediction
|
346 |
+
labels = collated_fbanks.new_zeros(collated_fbanks.size(0), collated_fbanks.size(1))
|
347 |
+
for i, l in enumerate(fbank_sizes):
|
348 |
+
labels[i, l - 1 :] = 1.0
|
349 |
+
|
350 |
+
spkembs = _collate_frames([s["spkembs"] for s in samples], is_audio_input=True)
|
351 |
+
|
352 |
+
targets_by_label = [
|
353 |
+
[s["label_list"][i] for s in samples] for i in range(self.num_labels)
|
354 |
+
]
|
355 |
+
targets_list, lengths_list, ntokens_list = self.collater_label(
|
356 |
+
targets_by_label, audio_size, audio_starts
|
357 |
+
)
|
358 |
+
|
359 |
+
net_input = {
|
360 |
+
"source": collated_audios,
|
361 |
+
"padding_mask": padding_mask,
|
362 |
+
"prev_output_tokens": prev_output_tokens,
|
363 |
+
"spkembs": spkembs,
|
364 |
+
"tgt_lengths": collated_fbanks_size_in,
|
365 |
+
}
|
366 |
+
|
367 |
+
batch = {
|
368 |
+
"id": torch.LongTensor([s["id"] for s in samples]),
|
369 |
+
"net_input": net_input,
|
370 |
+
"labels": labels,
|
371 |
+
"dec_target": collated_fbanks,
|
372 |
+
"dec_target_lengths": collated_fbanks_size,
|
373 |
+
"src_lengths": collated_audios_size,
|
374 |
+
"task_name": 'speech_pretrain',
|
375 |
+
}
|
376 |
+
|
377 |
+
if self.single_target:
|
378 |
+
batch["target_lengths"] = lengths_list[0]
|
379 |
+
batch["ntokens"] = ntokens_list[0]
|
380 |
+
batch["target"] = targets_list[0]
|
381 |
+
else:
|
382 |
+
batch["target_lengths_list"] = lengths_list
|
383 |
+
batch["ntokens_list"] = ntokens_list
|
384 |
+
batch["target_list"] = targets_list
|
385 |
+
return batch
|
386 |
+
|
387 |
+
def collater_audio(self, audios, audio_size):
|
388 |
+
collated_audios = audios[0].new_zeros(len(audios), audio_size)
|
389 |
+
padding_mask = (
|
390 |
+
torch.BoolTensor(collated_audios.shape).fill_(False)
|
391 |
+
# if self.pad_audio else None
|
392 |
+
)
|
393 |
+
audio_starts = [0 for _ in audios]
|
394 |
+
for i, audio in enumerate(audios):
|
395 |
+
diff = len(audio) - audio_size
|
396 |
+
if diff == 0:
|
397 |
+
collated_audios[i] = audio
|
398 |
+
elif diff < 0:
|
399 |
+
assert self.pad_audio
|
400 |
+
collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
|
401 |
+
padding_mask[i, diff:] = True
|
402 |
+
else:
|
403 |
+
collated_audios[i], audio_starts[i] = self.crop_to_max_size(
|
404 |
+
audio, audio_size
|
405 |
+
)
|
406 |
+
return collated_audios, padding_mask, audio_starts
|
407 |
+
|
408 |
+
def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad):
|
409 |
+
assert label_rate > 0
|
410 |
+
s2f = label_rate / self.sample_rate
|
411 |
+
frm_starts = [int(round(s * s2f)) for s in audio_starts]
|
412 |
+
frm_size = int(round(audio_size * s2f))
|
413 |
+
if not self.pad_audio:
|
414 |
+
rem_size = [len(t) - s for t, s in zip(targets, frm_starts)]
|
415 |
+
frm_size = min(frm_size, *rem_size)
|
416 |
+
targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)]
|
417 |
+
logger.debug(f"audio_starts={audio_starts}")
|
418 |
+
logger.debug(f"frame_starts={frm_starts}")
|
419 |
+
logger.debug(f"frame_size={frm_size}")
|
420 |
+
|
421 |
+
lengths = torch.LongTensor([len(t) for t in targets])
|
422 |
+
ntokens = lengths.sum().item()
|
423 |
+
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
|
424 |
+
return targets, lengths, ntokens
|
425 |
+
|
426 |
+
def collater_seq_label(self, targets, pad):
|
427 |
+
lengths = torch.LongTensor([len(t) for t in targets])
|
428 |
+
ntokens = lengths.sum().item()
|
429 |
+
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
|
430 |
+
return targets, lengths, ntokens
|
431 |
+
|
432 |
+
def collater_label(self, targets_by_label, audio_size, audio_starts):
|
433 |
+
targets_list, lengths_list, ntokens_list = [], [], []
|
434 |
+
itr = zip(targets_by_label, self.label_rates, self.pad_list)
|
435 |
+
for targets, label_rate, pad in itr:
|
436 |
+
if label_rate == -1.0:
|
437 |
+
targets, lengths, ntokens = self.collater_seq_label(targets, pad)
|
438 |
+
else:
|
439 |
+
targets, lengths, ntokens = self.collater_frm_label(
|
440 |
+
targets, audio_size, audio_starts, label_rate, pad
|
441 |
+
)
|
442 |
+
targets_list.append(targets)
|
443 |
+
lengths_list.append(lengths)
|
444 |
+
ntokens_list.append(ntokens)
|
445 |
+
return targets_list, lengths_list, ntokens_list
|
446 |
+
|
447 |
+
def num_tokens(self, index):
|
448 |
+
return self.size(index)
|
449 |
+
|
450 |
+
def size(self, index):
|
451 |
+
if self.pad_audio:
|
452 |
+
return self.sizes[index]
|
453 |
+
return min(self.sizes[index], self.max_sample_size)
|
454 |
+
|
455 |
+
def ordered_indices(self):
|
456 |
+
if self.shuffle:
|
457 |
+
order = [np.random.permutation(len(self))]
|
458 |
+
else:
|
459 |
+
order = [np.arange(len(self))]
|
460 |
+
|
461 |
+
order.append(self.sizes)
|
462 |
+
return np.lexsort(order)[::-1]
|
463 |
+
|
464 |
+
def postprocess(self, wav, cur_sample_rate):
|
465 |
+
if wav.dim() == 2:
|
466 |
+
wav = wav.mean(-1)
|
467 |
+
assert wav.dim() == 1, wav.dim()
|
468 |
+
|
469 |
+
if cur_sample_rate != self.sample_rate:
|
470 |
+
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
|
471 |
+
|
472 |
+
if self.normalize:
|
473 |
+
with torch.no_grad():
|
474 |
+
wav = F.layer_norm(wav, wav.shape)
|
475 |
+
return wav
|
artst/data/speech_to_class_dataset.py
ADDED
@@ -0,0 +1,260 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
# Based on speecht5, fairseq and espnet code bases
|
5 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
from typing import Any, List, Optional
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from fairseq.data import data_utils, Dictionary
|
17 |
+
from fairseq.data.fairseq_dataset import FairseqDataset
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
def load_audio(manifest_path, max_keep, min_keep):
|
23 |
+
"""manifest tsv: wav_path, wav_nframe, wav_class
|
24 |
+
|
25 |
+
Args
|
26 |
+
manifest_path: str
|
27 |
+
max_keep: int
|
28 |
+
min_keep: int
|
29 |
+
|
30 |
+
Return
|
31 |
+
root, names, inds, tot, sizes, classes
|
32 |
+
"""
|
33 |
+
n_long, n_short = 0, 0
|
34 |
+
names, inds, sizes, classes = [], [], [], []
|
35 |
+
with open(manifest_path) as f:
|
36 |
+
root = f.readline().strip()
|
37 |
+
for ind, line in enumerate(f):
|
38 |
+
items = line.strip().split("\t")
|
39 |
+
assert len(items) >= 2, line
|
40 |
+
sz = int(items[1])
|
41 |
+
if min_keep is not None and sz < min_keep:
|
42 |
+
n_short += 1
|
43 |
+
elif max_keep is not None and sz > max_keep:
|
44 |
+
n_long += 1
|
45 |
+
else:
|
46 |
+
names.append(items[0])
|
47 |
+
if len(items) > 2:
|
48 |
+
classes.append(items[2])
|
49 |
+
inds.append(ind)
|
50 |
+
sizes.append(sz)
|
51 |
+
tot = ind + 1
|
52 |
+
logger.info(
|
53 |
+
(
|
54 |
+
f"max_keep={max_keep}, min_keep={min_keep}, "
|
55 |
+
f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
|
56 |
+
f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
|
57 |
+
)
|
58 |
+
)
|
59 |
+
if len(classes) == 0:
|
60 |
+
logger.warn("no classes loaded only if inference")
|
61 |
+
return root, names, inds, tot, sizes, classes
|
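# [Editor's sketch, not part of the original diff] The manifest expected by load_audio()
# above is a TSV whose first line is the audio root and whose rows are
# "wav_path<TAB>wav_nframe<TAB>wav_class"; all paths, frame counts and class names below
# are hypothetical.
example = "\n".join([
    "/data/audio",
    "spk1/utt1.wav\t48000\tdialect_A",
    "spk1/utt2.wav\t16000\tdialect_B",
])
with open("/tmp/s2c_train.tsv", "w") as f:
    f.write(example + "\n")

root, names, inds, tot, sizes, classes = load_audio("/tmp/s2c_train.tsv", None, None)
# root == "/data/audio", classes == ["dialect_A", "dialect_B"], sizes == [48000, 16000]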
62 |
+
|
63 |
+
|
64 |
+
def sample_from_feature(x: np.ndarray, max_segment_length: int = 300):
|
65 |
+
"""Load a segment within 300-400/51200-76800 frames or the corresponding samples from a utterance.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
x (np.ndarray): feature or waveform (frames[, features]), e.g., log mel filter bank or waveform
|
69 |
+
max_segment_length (int, optional): maximum segment length. Defaults to 300.
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
np.ndarray: segmented features
|
73 |
+
"""
|
74 |
+
if len(x) <= max_segment_length:
|
75 |
+
return x
|
76 |
+
start = np.random.randint(0, x.shape[0] - max_segment_length)
|
77 |
+
return x[start: start + max_segment_length]
|
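# [Editor's sketch, not part of the original diff] Toy example of the random cropping
# performed by sample_from_feature() above; the waveform and lengths are made up.
import numpy as np

x = np.random.randn(16000 * 6)                         # ~6 s of 16 kHz audio (hypothetical)
seg = sample_from_feature(x, max_segment_length=16000 * 4)
assert len(seg) == 16000 * 4                           # cropped only because len(x) is larger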
78 |
+
|
79 |
+
|
80 |
+
class SpeechToClassDataset(FairseqDataset):
|
81 |
+
def __init__(
|
82 |
+
self,
|
83 |
+
manifest_path: str,
|
84 |
+
sample_rate: float,
|
85 |
+
label_processors: Optional[List[Any]] = None,
|
86 |
+
max_keep_sample_size: Optional[int] = None,
|
87 |
+
min_keep_sample_size: Optional[int] = None,
|
88 |
+
shuffle: bool = True,
|
89 |
+
normalize: bool = False,
|
90 |
+
tgt_dict: Optional[Dictionary] = None,
|
91 |
+
max_length: Optional[int] = None
|
92 |
+
):
|
93 |
+
self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.wav_classes = load_audio(
|
94 |
+
manifest_path, max_keep_sample_size, min_keep_sample_size
|
95 |
+
)
|
96 |
+
self.sample_rate = sample_rate
|
97 |
+
self.shuffle = shuffle
|
98 |
+
|
99 |
+
self.label_processors = label_processors
|
100 |
+
|
101 |
+
self.normalize = normalize
|
102 |
+
self.tgt_dict = tgt_dict
|
103 |
+
self.max_length = max_length
|
104 |
+
logger.info(
|
105 |
+
f"max_length={max_length}, normalize={normalize}"
|
106 |
+
)
|
107 |
+
|
108 |
+
def get_audio(self, index):
|
109 |
+
import soundfile as sf
|
110 |
+
|
111 |
+
wav_path = os.path.join(self.audio_root, self.audio_names[index])
|
112 |
+
wav, cur_sample_rate = sf.read(wav_path)
|
113 |
+
if self.max_length is not None:
|
114 |
+
wav = sample_from_feature(wav, self.max_length)
|
115 |
+
wav = torch.from_numpy(wav).float()
|
116 |
+
wav = self.postprocess(wav, cur_sample_rate)
|
117 |
+
return wav
|
118 |
+
|
119 |
+
def get_label(self, index):
|
120 |
+
label = self.wav_classes[index]
|
121 |
+
|
122 |
+
if self.label_processors is not None:
|
123 |
+
label = self.label_processors(label)
|
124 |
+
return label
|
125 |
+
|
126 |
+
def __getitem__(self, index):
|
127 |
+
wav = self.get_audio(index)
|
128 |
+
label = None
|
129 |
+
if len(self.wav_classes) == len(self.audio_names):
|
130 |
+
label = self.get_label(index)
|
131 |
+
return {"id": index, "source": wav, "label": label}
|
132 |
+
|
133 |
+
def __len__(self):
|
134 |
+
return len(self.wav_sizes)
|
135 |
+
|
136 |
+
def collater(self, samples):
|
137 |
+
samples = [s for s in samples if s["source"] is not None]
|
138 |
+
if len(samples) == 0:
|
139 |
+
return {}
|
140 |
+
|
141 |
+
audios = [s["source"] for s in samples]
|
142 |
+
audio_sizes = [len(s) for s in audios]
|
143 |
+
|
144 |
+
audio_size = max(audio_sizes)
|
145 |
+
collated_audios, padding_mask = self.collater_audio(
|
146 |
+
audios, audio_size
|
147 |
+
)
|
148 |
+
|
149 |
+
decoder_label = None
|
150 |
+
decoder_target = None
|
151 |
+
decoder_target_lengths = None
|
152 |
+
if samples[0]["label"] is not None:
|
153 |
+
targets_by_label = [
|
154 |
+
[s["label"] for s in samples]
|
155 |
+
]
|
156 |
+
targets_list, lengths_list, ntokens_list = self.collater_label(targets_by_label)
|
157 |
+
|
158 |
+
decoder_label = [
|
159 |
+
(targets_list[0][i, :lengths_list[0][i]]).long()
|
160 |
+
for i in range(targets_list[0].size(0))
|
161 |
+
]
|
162 |
+
|
163 |
+
decoder_target = data_utils.collate_tokens(
|
164 |
+
decoder_label,
|
165 |
+
self.tgt_dict.pad(),
|
166 |
+
self.tgt_dict.eos(),
|
167 |
+
left_pad=False,
|
168 |
+
move_eos_to_beginning=False,
|
169 |
+
)
|
170 |
+
decoder_target_lengths = torch.tensor(
|
171 |
+
[x.size(0) for x in decoder_label], dtype=torch.long
|
172 |
+
)
|
173 |
+
prev_output_tokens = data_utils.collate_tokens(
|
174 |
+
[torch.LongTensor([-1]) for _ in samples],
|
175 |
+
self.tgt_dict.pad(),
|
176 |
+
self.tgt_dict.eos(),
|
177 |
+
left_pad=False,
|
178 |
+
move_eos_to_beginning=True,
|
179 |
+
)
|
180 |
+
|
181 |
+
net_input = {
|
182 |
+
"source": collated_audios,
|
183 |
+
"padding_mask": padding_mask,
|
184 |
+
"prev_output_tokens": prev_output_tokens,
|
185 |
+
"task_name": "s2c",
|
186 |
+
}
|
187 |
+
batch = {
|
188 |
+
"id": torch.LongTensor([s["id"] for s in samples]),
|
189 |
+
"net_input": net_input,
|
190 |
+
"target": decoder_target,
|
191 |
+
"target_lengths": decoder_target_lengths,
|
192 |
+
"task_name": "s2c",
|
193 |
+
"ntokens": len(samples),
|
194 |
+
}
|
195 |
+
|
196 |
+
return batch
|
197 |
+
|
198 |
+
def collater_audio(self, audios, audio_size):
|
199 |
+
collated_audios = audios[0].new_zeros(len(audios), audio_size)
|
200 |
+
padding_mask = (
|
201 |
+
torch.BoolTensor(collated_audios.shape).fill_(False)
|
202 |
+
)
|
203 |
+
for i, audio in enumerate(audios):
|
204 |
+
diff = len(audio) - audio_size
|
205 |
+
if diff == 0:
|
206 |
+
collated_audios[i] = audio
|
207 |
+
elif diff < 0:
|
208 |
+
collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
|
209 |
+
padding_mask[i, diff:] = True
|
210 |
+
else:
|
211 |
+
raise Exception("Diff should not be larger than 0")
|
212 |
+
return collated_audios, padding_mask
|
213 |
+
|
214 |
+
def collater_seq_label(self, targets, pad):
|
215 |
+
lengths = torch.LongTensor([len(t) for t in targets])
|
216 |
+
ntokens = lengths.sum().item()
|
217 |
+
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
|
218 |
+
return targets, lengths, ntokens
|
219 |
+
|
220 |
+
def collater_label(self, targets_by_label):
|
221 |
+
targets_list, lengths_list, ntokens_list = [], [], []
|
222 |
+
itr = zip(targets_by_label, [self.tgt_dict.pad()])
|
223 |
+
for targets, pad in itr:
|
224 |
+
targets, lengths, ntokens = self.collater_seq_label(targets, pad)
|
225 |
+
targets_list.append(targets)
|
226 |
+
lengths_list.append(lengths)
|
227 |
+
ntokens_list.append(ntokens)
|
228 |
+
return targets_list, lengths_list, ntokens_list
|
229 |
+
|
230 |
+
def num_tokens(self, index):
|
231 |
+
return self.size(index)
|
232 |
+
|
233 |
+
def size(self, index):
|
234 |
+
return self.wav_sizes[index]
|
235 |
+
|
236 |
+
@property
|
237 |
+
def sizes(self):
|
238 |
+
return np.array(self.wav_sizes)
|
239 |
+
|
240 |
+
def ordered_indices(self):
|
241 |
+
if self.shuffle:
|
242 |
+
order = [np.random.permutation(len(self))]
|
243 |
+
else:
|
244 |
+
order = [np.arange(len(self))]
|
245 |
+
|
246 |
+
order.append(self.wav_sizes)
|
247 |
+
return np.lexsort(order)[::-1]
|
248 |
+
|
249 |
+
def postprocess(self, wav, cur_sample_rate):
|
250 |
+
if wav.dim() == 2:
|
251 |
+
wav = wav.mean(-1)
|
252 |
+
assert wav.dim() == 1, wav.dim()
|
253 |
+
|
254 |
+
if cur_sample_rate != self.sample_rate:
|
255 |
+
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
|
256 |
+
|
257 |
+
if self.normalize:
|
258 |
+
with torch.no_grad():
|
259 |
+
wav = F.layer_norm(wav, wav.shape)
|
260 |
+
return wav
|
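# [Editor's sketch, not part of the original diff] How ordered_indices() in these datasets
# sorts examples: the primary key is utterance size (descending), with a random permutation
# as tie-breaker when shuffle=True. The sizes below are made up.
import numpy as np

sizes = np.array([120, 400, 400, 80])
order = [np.random.permutation(len(sizes))]   # random tie-breaker (shuffle=True case)
order.append(sizes)                           # primary key
idx = np.lexsort(order)[::-1]                 # longest first
# idx is [1, 2, 0, 3] or [2, 1, 0, 3], depending on the random permutation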
artst/data/speech_to_speech_dataset.py
ADDED
@@ -0,0 +1,280 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
# Based on speecht5, fairseq and espnet code bases
|
5 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
from typing import Any, List, Optional
|
11 |
+
|
12 |
+
import librosa
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from fairseq.data.fairseq_dataset import FairseqDataset
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
def _collate_frames(
|
21 |
+
frames: List[torch.Tensor], is_audio_input: bool = False
|
22 |
+
):
|
23 |
+
"""
|
24 |
+
Convert a list of 2D frames into a padded 3D tensor
|
25 |
+
Args:
|
26 |
+
frames (list): list of 2D frames of size L[i]*f_dim, where L[i] is the
|
27 |
+
length of i-th frame and f_dim is static dimension of features
|
28 |
+
Returns:
|
29 |
+
3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
|
30 |
+
"""
|
31 |
+
max_len = max(frame.size(0) for frame in frames)
|
32 |
+
if is_audio_input:
|
33 |
+
out = frames[0].new_zeros((len(frames), max_len))
|
34 |
+
else:
|
35 |
+
out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
|
36 |
+
for i, v in enumerate(frames):
|
37 |
+
out[i, : v.size(0)] = v
|
38 |
+
return out
|
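# [Editor's sketch, not part of the original diff] Padding behaviour of _collate_frames()
# above on toy inputs: variable-length (T, 80) features become a zero-padded (B, Tmax, 80)
# batch, and 1-D inputs (waveforms, speaker embeddings) become a (B, Tmax) batch.
import torch

fbanks = [torch.randn(3, 80), torch.randn(5, 80)]
batch = _collate_frames(fbanks)                                   # shape (2, 5, 80)
spkembs = _collate_frames([torch.randn(512), torch.randn(512)],
                          is_audio_input=True)                    # shape (2, 512)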
39 |
+
|
40 |
+
def load_audio(manifest_path, max_keep, min_keep):
|
41 |
+
"""manifest tsv: src_wav, src_nframe, tgt_wav, tgt_nframe, tgt_spkemb"""
|
42 |
+
n_long, n_short = 0, 0
|
43 |
+
src_names, tgt_names, inds, sizes, tgt_sizes, spk_embeds = [], [], [], [], [], []
|
44 |
+
with open(manifest_path) as f:
|
45 |
+
root = f.readline().strip()
|
46 |
+
for ind, line in enumerate(f):
|
47 |
+
items = line.strip().split("\t")
|
48 |
+
assert len(items) >= 2, line
|
49 |
+
sz = int(items[1])
|
50 |
+
if min_keep is not None and sz < min_keep:
|
51 |
+
n_short += 1
|
52 |
+
elif max_keep is not None and sz > max_keep:
|
53 |
+
n_long += 1
|
54 |
+
else:
|
55 |
+
src_names.append(items[0])
|
56 |
+
tgt_names.append(items[2])
|
57 |
+
tgt_sizes.append(items[3])
|
58 |
+
spk_embeds.append(items[4])
|
59 |
+
inds.append(ind)
|
60 |
+
sizes.append(sz)
|
61 |
+
tot = ind + 1
|
62 |
+
logger.info(
|
63 |
+
(
|
64 |
+
f"max_keep={max_keep}, min_keep={min_keep}, "
|
65 |
+
f"loaded {len(src_names)}, skipped {n_short} short and {n_long} long, "
|
66 |
+
f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
|
67 |
+
)
|
68 |
+
)
|
69 |
+
return root, src_names, inds, tot, sizes, tgt_names, tgt_sizes, spk_embeds
|
70 |
+
|
71 |
+
|
72 |
+
def logmelfilterbank(
|
73 |
+
audio,
|
74 |
+
sampling_rate,
|
75 |
+
fft_size=1024,
|
76 |
+
hop_size=256,
|
77 |
+
win_length=None,
|
78 |
+
window="hann",
|
79 |
+
num_mels=80,
|
80 |
+
fmin=80,
|
81 |
+
fmax=7600,
|
82 |
+
eps=1e-10,
|
83 |
+
):
|
84 |
+
"""Compute log-Mel filterbank feature.
|
85 |
+
(https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/bin/preprocess.py)
|
86 |
+
|
87 |
+
Args:
|
88 |
+
audio (ndarray): Audio signal (T,).
|
89 |
+
sampling_rate (int): Sampling rate.
|
90 |
+
fft_size (int): FFT size.
|
91 |
+
hop_size (int): Hop size.
|
92 |
+
win_length (int): Window length. If set to None, it will be the same as fft_size.
|
93 |
+
window (str): Window function type.
|
94 |
+
num_mels (int): Number of mel basis.
|
95 |
+
fmin (int): Minimum frequency in mel basis calculation.
|
96 |
+
fmax (int): Maximum frequency in mel basis calculation.
|
97 |
+
eps (float): Epsilon value to avoid inf in log calculation.
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
ndarray: Log Mel filterbank feature (#frames, num_mels).
|
101 |
+
|
102 |
+
"""
|
103 |
+
# get amplitude spectrogram
|
104 |
+
x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size,
|
105 |
+
win_length=win_length, window=window, pad_mode="reflect")
|
106 |
+
spc = np.abs(x_stft).T # (#frames, #bins)
|
107 |
+
|
108 |
+
# get mel basis
|
109 |
+
fmin = 0 if fmin is None else fmin
|
110 |
+
fmax = sampling_rate / 2 if fmax is None else fmax
|
111 |
+
mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
112 |
+
|
113 |
+
return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
|
114 |
+
|
115 |
+
|
116 |
+
class SpeechToSpeechDataset(FairseqDataset):
|
117 |
+
def __init__(
|
118 |
+
self,
|
119 |
+
manifest_path: str,
|
120 |
+
sample_rate: float,
|
121 |
+
max_keep_sample_size: Optional[int] = None,
|
122 |
+
min_keep_sample_size: Optional[int] = None,
|
123 |
+
shuffle: bool = True,
|
124 |
+
normalize: bool = False,
|
125 |
+
reduction_factor: int = 1,
|
126 |
+
):
|
127 |
+
self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.tgt_audios, self.tgt_sizes, self.tgt_spkembs = load_audio(
|
128 |
+
manifest_path, max_keep_sample_size, min_keep_sample_size
|
129 |
+
)
|
130 |
+
self.sample_rate = sample_rate
|
131 |
+
self.shuffle = shuffle
|
132 |
+
|
133 |
+
self.normalize = normalize
|
134 |
+
self.reduction_factor = reduction_factor
|
135 |
+
logger.info(
|
136 |
+
f"reduction_factor={reduction_factor}, normalize={normalize}"
|
137 |
+
)
|
138 |
+
|
139 |
+
def get_audio(self, index):
|
140 |
+
import soundfile as sf
|
141 |
+
|
142 |
+
wav_fbank = []
|
143 |
+
for name in [self.audio_names[index], self.tgt_audios[index]]:
|
144 |
+
wav_path = os.path.join(self.audio_root, name)
|
145 |
+
wav, cur_sample_rate = sf.read(wav_path)
|
146 |
+
wav = torch.from_numpy(wav).float()
|
147 |
+
fbank = logmelfilterbank(
|
148 |
+
wav.view(-1).cpu().numpy(), 16000
|
149 |
+
)
|
150 |
+
fbank = torch.from_numpy(fbank).float()
|
151 |
+
wav = self.postprocess(wav, cur_sample_rate)
|
152 |
+
wav_fbank.append(wav)
|
153 |
+
wav_fbank.append(fbank)
|
154 |
+
src_wav, src_fbank, tgt_wav, tgt_fbank = wav_fbank
|
155 |
+
return src_wav, src_fbank, tgt_wav, tgt_fbank
|
156 |
+
|
157 |
+
def __getitem__(self, index):
|
158 |
+
src_wav, src_fbank, tgt_wav, tgt_fbank = self.get_audio(index)
|
159 |
+
spkembs = np.load(os.path.join(self.audio_root, self.tgt_spkembs[index]))
|
160 |
+
spkembs = torch.from_numpy(spkembs).float()
|
161 |
+
name = self.audio_names[index].replace("/", ".").replace(".wav", "") + "-" + self.tgt_audios[index].replace("/", ".").replace(".wav", "") + ".wav"
|
162 |
+
return {"id": index, "source": src_wav, "target": tgt_fbank, "spkembs": spkembs, "audio_name": name, "tgt_name": self.tgt_audios[index]}
|
163 |
+
|
164 |
+
def __len__(self):
|
165 |
+
return len(self.wav_sizes)
|
166 |
+
|
167 |
+
def collater(self, samples):
|
168 |
+
samples = [s for s in samples if s["source"] is not None]
|
169 |
+
if len(samples) == 0:
|
170 |
+
return {}
|
171 |
+
|
172 |
+
audios = [s["source"] for s in samples]
|
173 |
+
audio_sizes = [len(s) for s in audios]
|
174 |
+
|
175 |
+
audio_size = max(audio_sizes)
|
176 |
+
collated_audios, padding_mask = self.collater_audio(
|
177 |
+
audios, audio_size
|
178 |
+
)
|
179 |
+
|
180 |
+
fbanks = [s["target"] for s in samples]
|
181 |
+
fbank_sizes = [len(s) for s in fbanks]
|
182 |
+
|
183 |
+
collated_fbanks = _collate_frames(fbanks)
|
184 |
+
collated_fbanks_size = torch.tensor(fbank_sizes, dtype=torch.long)
|
185 |
+
|
186 |
+
# thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
|
187 |
+
if self.reduction_factor > 1:
|
188 |
+
collated_fbanks_in = collated_fbanks[:, self.reduction_factor - 1 :: self.reduction_factor]
|
189 |
+
collated_fbanks_size_in = collated_fbanks_size.new([torch.div(olen, self.reduction_factor, rounding_mode='floor') for olen in collated_fbanks_size])
|
190 |
+
else:
|
191 |
+
collated_fbanks_in, collated_fbanks_size_in = collated_fbanks, collated_fbanks_size
|
192 |
+
|
193 |
+
prev_output_tokens = torch.cat(
|
194 |
+
[collated_fbanks_in.new_zeros((collated_fbanks_in.shape[0], 1, collated_fbanks_in.shape[2])), collated_fbanks_in[:, :-1]], dim=1
|
195 |
+
)
|
196 |
+
|
197 |
+
# make labels for stop prediction
|
198 |
+
labels = collated_fbanks.new_zeros(collated_fbanks.size(0), collated_fbanks.size(1))
|
199 |
+
for i, l in enumerate(fbank_sizes):
|
200 |
+
labels[i, l - 1 :] = 1.0
|
201 |
+
|
202 |
+
spkembs = _collate_frames([s["spkembs"] for s in samples], is_audio_input=True)
|
203 |
+
|
204 |
+
net_input = {
|
205 |
+
"source": collated_audios,
|
206 |
+
"padding_mask": padding_mask,
|
207 |
+
"prev_output_tokens": prev_output_tokens,
|
208 |
+
"tgt_lengths": collated_fbanks_size_in,
|
209 |
+
"spkembs": spkembs,
|
210 |
+
"task_name": "s2s",
|
211 |
+
}
|
212 |
+
batch = {
|
213 |
+
"id": torch.LongTensor([s["id"] for s in samples]),
|
214 |
+
"name": [s["audio_name"] for s in samples],
|
215 |
+
"tgt_name": [s["tgt_name"] for s in samples],
|
216 |
+
"net_input": net_input,
|
217 |
+
"labels": labels,
|
218 |
+
"dec_target": collated_fbanks,
|
219 |
+
"dec_target_lengths": collated_fbanks_size,
|
220 |
+
"src_lengths": torch.LongTensor(audio_sizes),
|
221 |
+
"task_name": "s2s",
|
222 |
+
"ntokens": sum(audio_sizes),
|
223 |
+
"target": collated_fbanks,
|
224 |
+
}
|
225 |
+
|
226 |
+
return batch
|
227 |
+
|
228 |
+
def collater_audio(self, audios, audio_size):
|
229 |
+
collated_audios = audios[0].new_zeros(len(audios), audio_size)
|
230 |
+
padding_mask = (
|
231 |
+
torch.BoolTensor(collated_audios.shape).fill_(False)
|
232 |
+
)
|
233 |
+
for i, audio in enumerate(audios):
|
234 |
+
diff = len(audio) - audio_size
|
235 |
+
if diff == 0:
|
236 |
+
collated_audios[i] = audio
|
237 |
+
elif diff < 0:
|
238 |
+
collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
|
239 |
+
padding_mask[i, diff:] = True
|
240 |
+
else:
|
241 |
+
raise Exception("Diff should not be larger than 0")
|
242 |
+
return collated_audios, padding_mask
|
243 |
+
|
244 |
+
|
245 |
+
def num_tokens(self, index):
|
246 |
+
return self.wav_sizes[index]
|
247 |
+
|
248 |
+
def size(self, index):
|
249 |
+
return self.wav_sizes[index], self.tgt_sizes[index]
|
250 |
+
|
251 |
+
@property
|
252 |
+
def sizes(self):
|
253 |
+
return np.array(self.wav_sizes)
|
254 |
+
|
255 |
+
@property
|
256 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
257 |
+
"""No cache dataset if dataset is large-scale. Cache dataset for small dataset."""
|
258 |
+
return True
|
259 |
+
|
260 |
+
def ordered_indices(self):
|
261 |
+
if self.shuffle:
|
262 |
+
order = [np.random.permutation(len(self))]
|
263 |
+
else:
|
264 |
+
order = [np.arange(len(self))]
|
265 |
+
|
266 |
+
order.append(self.wav_sizes)
|
267 |
+
return np.lexsort(order)[::-1]
|
268 |
+
|
269 |
+
def postprocess(self, wav, cur_sample_rate):
|
270 |
+
if wav.dim() == 2:
|
271 |
+
wav = wav.mean(-1)
|
272 |
+
assert wav.dim() == 1, wav.dim()
|
273 |
+
|
274 |
+
if cur_sample_rate != self.sample_rate:
|
275 |
+
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
|
276 |
+
|
277 |
+
if self.normalize:
|
278 |
+
with torch.no_grad():
|
279 |
+
wav = F.layer_norm(wav, wav.shape)
|
280 |
+
return wav
|
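# [Editor's sketch, not part of the original diff] Toy illustration of the two decoder-input
# steps in the collaters above: frame thinning by the reduction factor, then a one-frame
# right shift with an all-zero "go" frame to build prev_output_tokens. Values are made up.
import torch

r = 2                                                              # reduction_factor
fbanks = torch.arange(6 * 4, dtype=torch.float).view(1, 6, 4)      # (B=1, Lmax=6, odim=4)
fbanks_in = fbanks[:, r - 1::r]                                    # keep every r-th frame -> (1, 3, 4)
prev_output_tokens = torch.cat(
    [fbanks_in.new_zeros(1, 1, 4), fbanks_in[:, :-1]], dim=1
)
# prev_output_tokens[0, 0] is all zeros; prev_output_tokens[0, t] == fbanks_in[0, t - 1] for t >= 1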
artst/data/speech_to_text_dataset.py
ADDED
@@ -0,0 +1,298 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
# Based on speecht5, fairseq and espnet code bases
|
5 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import itertools
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import mmap
|
12 |
+
from typing import Any, List, Optional
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
import torch
|
17 |
+
torch.set_printoptions(profile="full")
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from fairseq.data import data_utils, Dictionary
|
20 |
+
from fairseq.data.fairseq_dataset import FairseqDataset
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
def load_audio(manifest_path, max_keep, min_keep):
|
26 |
+
n_long, n_short = 0, 0
|
27 |
+
names, inds, sizes = [], [], []
|
28 |
+
with open(manifest_path) as f:
|
29 |
+
root = f.readline().strip()
|
30 |
+
for ind, line in enumerate(f):
|
31 |
+
items = line.strip().split("\t")
|
32 |
+
assert len(items) >= 2, line
|
33 |
+
sz = int(items[1])
|
34 |
+
if min_keep is not None and sz < min_keep:
|
35 |
+
n_short += 1
|
36 |
+
elif max_keep is not None and sz > max_keep:
|
37 |
+
n_long += 1
|
38 |
+
else:
|
39 |
+
names.append(items[0])
|
40 |
+
inds.append(ind)
|
41 |
+
sizes.append(sz)
|
42 |
+
tot = ind + 1
|
43 |
+
logger.info(
|
44 |
+
(
|
45 |
+
f"max_keep={max_keep}, min_keep={min_keep}, "
|
46 |
+
f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
|
47 |
+
f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
|
48 |
+
)
|
49 |
+
)
|
50 |
+
return root, names, inds, tot, sizes
|
51 |
+
|
52 |
+
|
53 |
+
def load_label(label_path, inds, tot):
|
54 |
+
with open(label_path) as f:
|
55 |
+
labels = [line.rstrip() for line in f]
|
56 |
+
assert (
|
57 |
+
len(labels) == tot
|
58 |
+
), f"number of labels does not match ({len(labels)} != {tot})"
|
59 |
+
labels = [labels[i] for i in inds]
|
60 |
+
return labels
|
61 |
+
|
62 |
+
|
63 |
+
def load_label_offset(label_path, inds, tot):
|
64 |
+
with open(label_path) as f:
|
65 |
+
# Hawau:
|
66 |
+
# changed line length reading as it's incorrect
|
67 |
+
code_lengths = [len(line.encode("utf-8")) for line in f] #original
|
68 |
+
# code_lengths = [len(line) for line in f] #fix
|
69 |
+
assert (
|
70 |
+
len(code_lengths) == tot
|
71 |
+
), f"number of labels does not match ({len(code_lengths)} != {tot})"
|
72 |
+
offsets = list(itertools.accumulate([0] + code_lengths))
|
73 |
+
offsets = [(offsets[i], offsets[i + 1]) for i in inds]
|
74 |
+
return offsets
|
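# [Editor's sketch, not part of the original diff] Byte-offset indexing as done by
# load_label_offset() above and the mmap read in get_label() below: offsets are cumulative
# UTF-8 byte lengths, so one label can be sliced out of the file without loading it all.
# The file path and label text are hypothetical.
import itertools
import mmap

lines = ["مرحبا بالعالم\n", "hello world\n"]
with open("/tmp/labels.txt", "w", encoding="utf-8") as f:
    f.writelines(lines)

lengths = [len(line.encode("utf-8")) for line in lines]
offsets = list(itertools.accumulate([0] + lengths))       # [0, len(line0), len(line0)+len(line1)]
with open("/tmp/labels.txt", encoding="utf-8") as f:
    with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        second = mm[offsets[1]:offsets[2]].decode("utf-8")
# second == "hello world\n"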
75 |
+
|
76 |
+
|
77 |
+
class SpeechToTextDataset(FairseqDataset):
|
78 |
+
def __init__(
|
79 |
+
self,
|
80 |
+
manifest_path: str,
|
81 |
+
sample_rate: float,
|
82 |
+
label_paths: List[str],
|
83 |
+
label_processors: Optional[List[Any]] = None,
|
84 |
+
max_keep_sample_size: Optional[int] = None,
|
85 |
+
min_keep_sample_size: Optional[int] = None,
|
86 |
+
shuffle: bool = True,
|
87 |
+
normalize: bool = False,
|
88 |
+
store_labels: bool = True,
|
89 |
+
tgt_dict: Optional[Dictionary] = None,
|
90 |
+
tokenizer = None,
|
91 |
+
):
|
92 |
+
self.audio_root, self.audio_names, inds, tot, self.wav_sizes = load_audio(
|
93 |
+
manifest_path, max_keep_sample_size, min_keep_sample_size
|
94 |
+
)
|
95 |
+
|
96 |
+
self.sample_rate = sample_rate
|
97 |
+
self.shuffle = shuffle
|
98 |
+
self.tgt_dict = tgt_dict
|
99 |
+
self.tokenizer = tokenizer
|
100 |
+
|
101 |
+
self.num_labels = len(label_paths)
|
102 |
+
self.label_processors = label_processors
|
103 |
+
self.store_labels = store_labels
|
104 |
+
|
105 |
+
if store_labels:
|
106 |
+
self.label_list = [load_label(p, inds, tot) for p in label_paths]
|
107 |
+
logger.info(f"label_list: {self.label_list}")
|
108 |
+
else:
|
109 |
+
self.label_paths = label_paths
|
110 |
+
self.label_offsets_list = [
|
111 |
+
load_label_offset(p, inds, tot) for p in label_paths
|
112 |
+
]
|
113 |
+
# logger.info(f"label_offsets_list: {self.label_offsets_list}")
|
114 |
+
assert label_processors is None or len(label_processors) == self.num_labels
|
115 |
+
|
116 |
+
self.normalize = normalize
|
117 |
+
logger.info(
|
118 |
+
f"normalize={normalize}"
|
119 |
+
)
|
120 |
+
|
121 |
+
def get_audio(self, index):
|
122 |
+
import soundfile as sf
|
123 |
+
# Hawau:
|
124 |
+
# logger.info(f"loaded_audio: {self.audio_names[index]}")
|
125 |
+
wav_path = os.path.join(self.audio_root, self.audio_names[index])
|
126 |
+
wav, cur_sample_rate = sf.read(wav_path)
|
127 |
+
wav = torch.from_numpy(wav).float()
|
128 |
+
wav = self.postprocess(wav, cur_sample_rate)
|
129 |
+
return wav
|
130 |
+
|
131 |
+
def get_label(self, index, label_idx):
|
132 |
+
if self.store_labels:
|
133 |
+
label = self.label_list[label_idx][index]
|
134 |
+
else:
|
135 |
+
# list slicing method
|
136 |
+
# with open(self.label_paths[label_idx]) as f:
|
137 |
+
# offset_s, offset_e = self.label_offsets_list[label_idx][index]
|
138 |
+
# # Hawau:
|
139 |
+
# # f.seek(offset_s)
|
140 |
+
# # label = f.read(offset_e - offset_s)
|
141 |
+
# label = f.read()[offset_s : offset_e]
|
142 |
+
# Hawau:
|
143 |
+
# mmap method
|
144 |
+
with open(self.label_paths[label_idx], encoding='utf-8') as f:
|
145 |
+
with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
|
146 |
+
offset_s, offset_e = self.label_offsets_list[label_idx][index]
|
147 |
+
label = mm[offset_s:offset_e].decode("utf-8")
|
148 |
+
|
149 |
+
|
150 |
+
# Hawau:
|
151 |
+
# logger.info(f"loaded_label: {label}")
|
152 |
+
if self.tokenizer is not None:
|
153 |
+
label = self.tokenizer.encode(label)
|
154 |
+
|
155 |
+
if self.label_processors is not None:
|
156 |
+
label = self.label_processors[label_idx](label)
|
157 |
+
# logger.info(f"processed_label: {label}")
|
158 |
+
return label
|
159 |
+
|
160 |
+
def get_labels(self, index):
|
161 |
+
return [self.get_label(index, i) for i in range(self.num_labels)]
|
162 |
+
|
163 |
+
def __getitem__(self, index):
|
164 |
+
wav = self.get_audio(index)
|
165 |
+
labels = self.get_labels(index)
|
166 |
+
return {"id": index, "source": wav, "label_list": labels}
|
167 |
+
|
168 |
+
def __len__(self):
|
169 |
+
return len(self.wav_sizes)
|
170 |
+
|
171 |
+
def collater(self, samples):
|
172 |
+
samples = [s for s in samples if s["source"] is not None]
|
173 |
+
if len(samples) == 0:
|
174 |
+
return {}
|
175 |
+
|
176 |
+
audios = [s["source"] for s in samples]
|
177 |
+
audio_sizes = [len(s) for s in audios]
|
178 |
+
|
179 |
+
audio_size = max(audio_sizes)
|
180 |
+
collated_audios, padding_mask = self.collater_audio(
|
181 |
+
audios, audio_size
|
182 |
+
)
|
183 |
+
|
184 |
+
targets_by_label = [
|
185 |
+
[s["label_list"][i] for s in samples] for i in range(self.num_labels)
|
186 |
+
]
|
187 |
+
targets_list, lengths_list, ntokens_list = self.collater_label(targets_by_label)
|
188 |
+
|
189 |
+
# Hawau:
|
190 |
+
# logger.info(f'targets_list: {targets_list}')
|
191 |
+
|
192 |
+
|
193 |
+
decoder_label = [
|
194 |
+
torch.cat((targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
|
195 |
+
for i in range(targets_list[0].size(0))
|
196 |
+
]
|
197 |
+
|
198 |
+
decoder_target = data_utils.collate_tokens(
|
199 |
+
decoder_label,
|
200 |
+
self.tgt_dict.pad(),
|
201 |
+
self.tgt_dict.eos(),
|
202 |
+
left_pad=False,
|
203 |
+
move_eos_to_beginning=False,
|
204 |
+
)
|
205 |
+
decoder_target_lengths = torch.tensor(
|
206 |
+
[x.size(0) for x in decoder_label], dtype=torch.long
|
207 |
+
)
|
208 |
+
prev_output_tokens = data_utils.collate_tokens(
|
209 |
+
decoder_label,
|
210 |
+
self.tgt_dict.pad(),
|
211 |
+
self.tgt_dict.eos(),
|
212 |
+
left_pad=False,
|
213 |
+
move_eos_to_beginning=True,
|
214 |
+
)
|
215 |
+
|
216 |
+
net_input = {
|
217 |
+
"source": collated_audios,
|
218 |
+
"padding_mask": padding_mask,
|
219 |
+
"prev_output_tokens": prev_output_tokens,
|
220 |
+
"task_name": "s2t",
|
221 |
+
}
|
222 |
+
batch = {
|
223 |
+
"id": torch.LongTensor([s["id"] for s in samples]),
|
224 |
+
"net_input": net_input,
|
225 |
+
"target": decoder_target,
|
226 |
+
"target_lengths": decoder_target_lengths,
|
227 |
+
"task_name": "s2t",
|
228 |
+
"ntokens": ntokens_list[0]
|
229 |
+
}
|
230 |
+
|
231 |
+
return batch
|
232 |
+
|
233 |
+
def collater_audio(self, audios, audio_size):
|
234 |
+
collated_audios = audios[0].new_zeros(len(audios), audio_size)
|
235 |
+
padding_mask = (
|
236 |
+
torch.BoolTensor(collated_audios.shape).fill_(False)
|
237 |
+
)
|
238 |
+
for i, audio in enumerate(audios):
|
239 |
+
diff = len(audio) - audio_size
|
240 |
+
if diff == 0:
|
241 |
+
collated_audios[i] = audio
|
242 |
+
elif diff < 0:
|
243 |
+
collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)])
|
244 |
+
padding_mask[i, diff:] = True
|
245 |
+
else:
|
246 |
+
raise Exception("Diff should not be larger than 0")
|
247 |
+
return collated_audios, padding_mask
|
248 |
+
|
249 |
+
def collater_seq_label(self, targets, pad):
|
250 |
+
lengths = torch.LongTensor([len(t) for t in targets])
|
251 |
+
ntokens = lengths.sum().item()
|
252 |
+
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
|
253 |
+
return targets, lengths, ntokens
|
254 |
+
|
255 |
+
def collater_label(self, targets_by_label):
|
256 |
+
targets_list, lengths_list, ntokens_list = [], [], []
|
257 |
+
itr = zip(targets_by_label, [self.tgt_dict.pad()])
|
258 |
+
|
259 |
+
for targets, pad in itr:
|
260 |
+
# Hawau:
|
261 |
+
# logger.info(f'targets: {targets}')
|
262 |
+
targets, lengths, ntokens = self.collater_seq_label(targets, pad)
|
263 |
+
targets_list.append(targets)
|
264 |
+
lengths_list.append(lengths)
|
265 |
+
ntokens_list.append(ntokens)
|
266 |
+
return targets_list, lengths_list, ntokens_list
|
267 |
+
|
268 |
+
def num_tokens(self, index):
|
269 |
+
return self.size(index)
|
270 |
+
|
271 |
+
def size(self, index):
|
272 |
+
return self.wav_sizes[index]
|
273 |
+
|
274 |
+
@property
|
275 |
+
def sizes(self):
|
276 |
+
return np.array(self.wav_sizes)
|
277 |
+
|
278 |
+
def ordered_indices(self):
|
279 |
+
if self.shuffle:
|
280 |
+
order = [np.random.permutation(len(self))]
|
281 |
+
else:
|
282 |
+
order = [np.arange(len(self))]
|
283 |
+
|
284 |
+
order.append(self.wav_sizes)
|
285 |
+
return np.lexsort(order)[::-1]
|
286 |
+
|
287 |
+
def postprocess(self, wav, cur_sample_rate):
|
288 |
+
if wav.dim() == 2:
|
289 |
+
wav = wav.mean(-1)
|
290 |
+
assert wav.dim() == 1, wav.dim()
|
291 |
+
|
292 |
+
if cur_sample_rate != self.sample_rate:
|
293 |
+
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
|
294 |
+
|
295 |
+
if self.normalize:
|
296 |
+
with torch.no_grad():
|
297 |
+
wav = F.layer_norm(wav, wav.shape)
|
298 |
+
return wav
|
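# [Editor's sketch, not part of the original diff] Torch-only illustration of what the
# collate_tokens calls above produce for speech-to-text: decoder_target is each label
# sequence plus eos, right-padded; prev_output_tokens is the same content shifted right
# with eos moved to the front. pad=1 and eos=2 are assumed (fairseq Dictionary defaults).
import torch

pad, eos = 1, 2
seqs = [torch.tensor([10, 11, 12]), torch.tensor([20, 21])]        # hypothetical token ids

with_eos = [torch.cat([s, torch.tensor([eos])]) for s in seqs]
max_len = max(len(s) for s in with_eos)
decoder_target = torch.stack(
    [torch.cat([s, s.new_full((max_len - len(s),), pad)]) for s in with_eos]
)
prev_output_tokens = torch.stack(
    [torch.cat([torch.tensor([eos]), s[:-1], s.new_full((max_len - len(s),), pad)])
     for s in with_eos]
)
# decoder_target     -> [[10, 11, 12, 2], [20, 21, 2, 1]]
# prev_output_tokens -> [[ 2, 10, 11, 12], [ 2, 20, 21, 1]]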
artst/data/text_dataset.py
ADDED
@@ -0,0 +1,474 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
# Based on speecht5, fairseq and espnet code bases
|
5 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import math
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
|
13 |
+
from fairseq.data import FairseqDataset, data_utils
|
14 |
+
|
15 |
+
|
16 |
+
def collate(
|
17 |
+
samples,
|
18 |
+
pad_idx,
|
19 |
+
eos_idx,
|
20 |
+
vocab,
|
21 |
+
left_pad_source=False,
|
22 |
+
left_pad_target=False,
|
23 |
+
input_feeding=True,
|
24 |
+
pad_to_length=None,
|
25 |
+
):
|
26 |
+
assert input_feeding
|
27 |
+
if len(samples) == 0:
|
28 |
+
return {}
|
29 |
+
|
30 |
+
def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
|
31 |
+
return data_utils.collate_tokens(
|
32 |
+
[s[key] for s in samples],
|
33 |
+
pad_idx,
|
34 |
+
eos_idx=None, # use eos_idx of each sample instead of vocab.eos()
|
35 |
+
left_pad=left_pad,
|
36 |
+
move_eos_to_beginning=move_eos_to_beginning,
|
37 |
+
pad_to_length=pad_to_length,
|
38 |
+
)
|
39 |
+
|
40 |
+
id = torch.LongTensor([s["id"] for s in samples])
|
41 |
+
src_tokens = merge(
|
42 |
+
"source",
|
43 |
+
left_pad=left_pad_source,
|
44 |
+
pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
|
45 |
+
)
|
46 |
+
# sort by descending source length
|
47 |
+
src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
|
48 |
+
src_lengths, sort_order = src_lengths.sort(descending=True)
|
49 |
+
id = id.index_select(0, sort_order)
|
50 |
+
src_tokens = src_tokens.index_select(0, sort_order)
|
51 |
+
|
52 |
+
prev_output_tokens = None
|
53 |
+
target = None
|
54 |
+
if samples[0].get("target", None) is not None:
|
55 |
+
target = merge(
|
56 |
+
"target",
|
57 |
+
left_pad=left_pad_target,
|
58 |
+
pad_to_length=pad_to_length["target"]
|
59 |
+
if pad_to_length is not None
|
60 |
+
else None,
|
61 |
+
)
|
62 |
+
target = target.index_select(0, sort_order)
|
63 |
+
ntokens = sum(len(s["target"]) for s in samples)
|
64 |
+
|
65 |
+
if input_feeding:
|
66 |
+
# we create a shifted version of targets for feeding the
|
67 |
+
# previous output token(s) into the next decoder step
|
68 |
+
prev_output_tokens = merge(
|
69 |
+
"target",
|
70 |
+
left_pad=left_pad_target,
|
71 |
+
move_eos_to_beginning=True,
|
72 |
+
pad_to_length=pad_to_length["target"]
|
73 |
+
if pad_to_length is not None
|
74 |
+
else None,
|
75 |
+
)
|
76 |
+
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
|
77 |
+
else:
|
78 |
+
ntokens = sum(len(s["source"]) for s in samples)
|
79 |
+
|
80 |
+
batch = {
|
81 |
+
"id": id,
|
82 |
+
"ntokens": ntokens,
|
83 |
+
"net_input": {
|
84 |
+
"src_tokens": src_tokens,
|
85 |
+
"src_lengths": src_lengths,
|
86 |
+
},
|
87 |
+
"target": target,
|
88 |
+
"nsentences": samples[0]["source"].size(0),
|
89 |
+
"sort_order": sort_order,
|
90 |
+
"task_name": 'text_pretrain',
|
91 |
+
}
|
92 |
+
if prev_output_tokens is not None:
|
93 |
+
batch["net_input"]["prev_output_tokens"] = prev_output_tokens
|
94 |
+
|
95 |
+
return batch
|
96 |
+
|
97 |
+
|
98 |
+
class TextPretrainDataset(FairseqDataset):
|
99 |
+
"""
|
100 |
+
A wrapper around TokenBlockDataset for BART dataset.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
dataset (TokenBlockDataset): dataset to wrap
|
104 |
+
sizes (List[int]): sentence lengths
|
105 |
+
vocab (~fairseq.data.Dictionary): vocabulary
|
106 |
+
mask_idx (int): dictionary index used for masked token
|
107 |
+
mask_whole_words: only mask whole words. This should be a byte mask
|
108 |
+
over vocab indices, indicating whether it is the beginning of a
|
109 |
+
word. We will extend any mask to encompass the whole word.
|
110 |
+
shuffle (bool, optional): shuffle the elements before batching.
|
111 |
+
Default: ``True``
|
112 |
+
seed: Seed for random number generator for reproducibility.
|
113 |
+
args: argparse arguments.
|
114 |
+
"""
|
115 |
+
|
116 |
+
def __init__(
|
117 |
+
self,
|
118 |
+
dataset,
|
119 |
+
sizes,
|
120 |
+
vocab,
|
121 |
+
mask_idx,
|
122 |
+
mask_whole_words,
|
123 |
+
shuffle,
|
124 |
+
seed,
|
125 |
+
args,
|
126 |
+
eos=None,
|
127 |
+
item_transform_func=None,
|
128 |
+
iid_noise_target=False,
|
129 |
+
uni_mask_idxs=None,
|
130 |
+
):
|
131 |
+
self.dataset = dataset
|
132 |
+
|
133 |
+
self.sizes = sizes
|
134 |
+
|
135 |
+
self.vocab = vocab
|
136 |
+
self.shuffle = shuffle
|
137 |
+
self.seed = seed
|
138 |
+
if iid_noise_target:
|
139 |
+
assert isinstance(uni_mask_idxs, torch.Tensor), "if use iid_noise_target, the uni_mask_idxs must be a tensor which contain the mask indexs"
|
140 |
+
self.iid_noise_target = iid_noise_target
|
141 |
+
self.uni_mask_idxs = uni_mask_idxs
|
142 |
+
self.mask_idx = mask_idx
|
143 |
+
self.mask_whole_word = mask_whole_words
|
144 |
+
self.mask_ratio = args.mask
|
145 |
+
self.random_ratio = args.mask_random
|
146 |
+
self.insert_ratio = args.insert
|
147 |
+
self.rotate_ratio = args.rotate
|
148 |
+
self.permute_sentence_ratio = args.permute_sentences
|
149 |
+
self.eos = eos if eos is not None else vocab.eos()
|
150 |
+
self.item_transform_func = item_transform_func
|
151 |
+
|
152 |
+
if args.bpe != "gpt2":
|
153 |
+
self.full_stop_index = self.vocab.eos()
|
154 |
+
else:
|
155 |
+
assert args.bpe == "gpt2"
|
156 |
+
self.full_stop_index = self.vocab.index("13")
|
157 |
+
|
158 |
+
self.replace_length = args.replace_length
|
159 |
+
if self.replace_length not in [-1, 0, 1]:
|
160 |
+
raise ValueError(f"invalid arg: replace_length={self.replace_length}")
|
161 |
+
if args.mask_length not in ["subword", "word", "span-poisson"]:
|
162 |
+
raise ValueError(f"invalid arg: mask-length={args.mask_length}")
|
163 |
+
if args.mask_length == "subword" and args.replace_length not in [0, 1]:
|
164 |
+
raise ValueError(f"if using subwords, use replace-length=1 or 0")
|
165 |
+
|
166 |
+
self.mask_span_distribution = None
|
167 |
+
if args.mask_length == "span-poisson":
|
168 |
+
_lambda = args.poisson_lambda
|
169 |
+
|
170 |
+
lambda_to_the_k = 1
|
171 |
+
e_to_the_minus_lambda = math.exp(-_lambda)
|
172 |
+
k_factorial = 1
|
173 |
+
ps = []
|
174 |
+
for k in range(0, 128):
|
175 |
+
ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
|
176 |
+
lambda_to_the_k *= _lambda
|
177 |
+
k_factorial *= k + 1
|
178 |
+
if ps[-1] < 0.0000001:
|
179 |
+
break
|
180 |
+
ps = torch.FloatTensor(ps)
|
181 |
+
self.mask_span_distribution = torch.distributions.Categorical(ps)
|
182 |
+
|
183 |
+
self.epoch = 0
|
184 |
+
|
185 |
+
@property
|
186 |
+
def can_reuse_epoch_itr_across_epochs(self):
|
187 |
+
return True # only the noise changes, not item sizes
|
188 |
+
|
189 |
+
def set_epoch(self, epoch, **unused):
|
190 |
+
self.epoch = epoch
|
191 |
+
|
192 |
+
def __getitem__(self, index):
|
193 |
+
with data_utils.numpy_seed(self.seed, self.epoch, index):
|
194 |
+
tokens = self.dataset[index]
|
195 |
+
assert tokens[-1] == self.eos
|
196 |
+
source, target = tokens, tokens.clone()
|
197 |
+
|
198 |
+
if self.permute_sentence_ratio > 0.0:
|
199 |
+
source = self.permute_sentences(source, self.permute_sentence_ratio)
|
200 |
+
|
201 |
+
if self.mask_ratio > 0:
|
202 |
+
source, new_target = self.add_whole_word_mask(source, self.mask_ratio)
|
203 |
+
if new_target is not None:
|
204 |
+
target = new_target
|
205 |
+
|
206 |
+
if self.insert_ratio > 0:
|
207 |
+
source = self.add_insertion_noise(source, self.insert_ratio)
|
208 |
+
|
209 |
+
if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
|
210 |
+
source = self.add_rolling_noise(source)
|
211 |
+
# there can be additional changes to make:
|
212 |
+
if self.item_transform_func is not None:
|
213 |
+
source, target = self.item_transform_func(source, target)
|
214 |
+
|
215 |
+
assert (source >= 0).all()
|
216 |
+
assert (source[1:-1] >= 1).all()
|
217 |
+
assert (source <= len(self.vocab)).all()
|
218 |
+
assert source[0] == self.vocab.bos()
|
219 |
+
assert source[-1] == self.eos
|
220 |
+
return {
|
221 |
+
"id": index,
|
222 |
+
"source": source,
|
223 |
+
"target": target,
|
224 |
+
}
|
225 |
+
|
226 |
+
def __len__(self):
|
227 |
+
return len(self.dataset)
|
228 |
+
|
229 |
+
def permute_sentences(self, source, p=1.0):
|
230 |
+
full_stops = source == self.full_stop_index
|
231 |
+
# Pretend it ends with a full stop so last span is a sentence
|
232 |
+
full_stops[-2] = 1
|
233 |
+
|
234 |
+
# Tokens that are full stops, where the previous token is not
|
235 |
+
sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
|
236 |
+
result = source.clone()
|
237 |
+
|
238 |
+
num_sentences = sentence_ends.size(0)
|
239 |
+
num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
|
240 |
+
substitutions = torch.randperm(num_sentences)[:num_to_permute]
|
241 |
+
ordering = torch.arange(0, num_sentences)
|
242 |
+
ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]
|
243 |
+
|
244 |
+
# Ignore <bos> at start
|
245 |
+
index = 1
|
246 |
+
for i in ordering:
|
247 |
+
sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
|
248 |
+
result[index : index + sentence.size(0)] = sentence
|
249 |
+
index += sentence.size(0)
|
250 |
+
return result
|
251 |
+
|
252 |
+
def word_starts(self, source):
|
253 |
+
if self.mask_whole_word is not None:
|
254 |
+
is_word_start = self.mask_whole_word.gather(0, source)
|
255 |
+
else:
|
256 |
+
is_word_start = torch.ones(source.size())
|
257 |
+
is_word_start[0] = 0
|
258 |
+
is_word_start[-1] = 0
|
259 |
+
return is_word_start
|
260 |
+
|
261 |
+
def add_whole_word_mask(self, source, p):
|
262 |
+
source_ori = source.clone()
|
263 |
+
is_word_start = self.word_starts(source)
|
264 |
+
num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
|
265 |
+
num_inserts = 0
|
266 |
+
if num_to_mask == 0:
|
267 |
+
return source
|
268 |
+
|
269 |
+
if self.mask_span_distribution is not None:
|
270 |
+
lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))
|
271 |
+
|
272 |
+
# Make sure we have enough to mask
|
273 |
+
cum_length = torch.cumsum(lengths, 0)
|
274 |
+
while cum_length[-1] < num_to_mask:
|
275 |
+
lengths = torch.cat(
|
276 |
+
[
|
277 |
+
lengths,
|
278 |
+
self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
|
279 |
+
],
|
280 |
+
dim=0,
|
281 |
+
)
|
282 |
+
cum_length = torch.cumsum(lengths, 0)
|
283 |
+
|
284 |
+
# Trim to masking budget
|
285 |
+
i = 0
|
286 |
+
while cum_length[i] < num_to_mask:
|
287 |
+
i += 1
|
288 |
+
lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
|
289 |
+
num_to_mask = i + 1
|
290 |
+
lengths = lengths[:num_to_mask]
|
291 |
+
|
292 |
+
# Handle 0-length mask (inserts) separately
|
293 |
+
lengths = lengths[lengths > 0]
|
294 |
+
num_inserts = num_to_mask - lengths.size(0)
|
295 |
+
num_to_mask -= num_inserts
|
296 |
+
if num_to_mask == 0:
|
297 |
+
return self.add_insertion_noise(source, num_inserts / source.size(0))
|
298 |
+
|
299 |
+
assert (lengths > 0).all()
|
300 |
+
else:
|
301 |
+
lengths = torch.ones((num_to_mask,)).long()
|
302 |
+
assert is_word_start[-1] == 0
|
303 |
+
word_starts = is_word_start.nonzero(as_tuple=False)
|
304 |
+
indices = word_starts[
|
305 |
+
torch.randperm(word_starts.size(0))[:num_to_mask]
|
306 |
+
].squeeze(1)
|
307 |
+
mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio
|
308 |
+
|
309 |
+
source_length = source.size(0)
|
310 |
+
assert source_length - 1 not in indices
|
311 |
+
to_keep = torch.ones(source_length, dtype=torch.bool)
|
312 |
+
is_word_start[
|
313 |
+
-1
|
314 |
+
] = 255 # acts as a long length, so spans don't go over the end of doc
|
315 |
+
if self.replace_length == 0:
|
316 |
+
to_keep[indices] = 0
|
317 |
+
else:
|
318 |
+
# keep index, but replace it with [MASK]
|
319 |
+
source[indices] = self.mask_idx
|
320 |
+
source[indices[mask_random]] = torch.randint(
|
321 |
+
1, len(self.vocab), size=(mask_random.sum(),)
|
322 |
+
)
|
323 |
+
|
324 |
+
if self.mask_span_distribution is not None:
|
325 |
+
assert len(lengths.size()) == 1
|
326 |
+
assert lengths.size() == indices.size()
|
327 |
+
lengths -= 1
|
328 |
+
while indices.size(0) > 0:
|
329 |
+
assert lengths.size() == indices.size()
|
330 |
+
lengths -= is_word_start[indices + 1].long()
|
331 |
+
uncompleted = lengths >= 0
|
332 |
+
indices = indices[uncompleted] + 1
|
333 |
+
mask_random = mask_random[uncompleted]
|
334 |
+
lengths = lengths[uncompleted]
|
335 |
+
if self.replace_length != -1:
|
336 |
+
# delete token
|
337 |
+
to_keep[indices] = 0
|
338 |
+
else:
|
339 |
+
# keep index, but replace it with [MASK]
|
340 |
+
source[indices] = self.mask_idx
|
341 |
+
source[indices[mask_random]] = torch.randint(
|
342 |
+
1, len(self.vocab), size=(mask_random.sum(),)
|
343 |
+
)
|
344 |
+
else:
|
345 |
+
# A bit faster when all lengths are 1
|
346 |
+
while indices.size(0) > 0:
|
347 |
+
uncompleted = is_word_start[indices + 1] == 0
|
348 |
+
indices = indices[uncompleted] + 1
|
349 |
+
mask_random = mask_random[uncompleted]
|
350 |
+
if self.replace_length != -1:
|
351 |
+
# delete token
|
352 |
+
to_keep[indices] = 0
|
353 |
+
else:
|
354 |
+
# keep index, but replace it with [MASK]
|
355 |
+
source[indices] = self.mask_idx
|
356 |
+
source[indices[mask_random]] = torch.randint(
|
357 |
+
1, len(self.vocab), size=(mask_random.sum(),)
|
358 |
+
)
|
359 |
+
|
360 |
+
assert source_length - 1 not in indices
|
361 |
+
|
362 |
+
if not self.iid_noise_target:
|
363 |
+
source = source[to_keep]
|
364 |
+
target = None
|
365 |
+
else:
|
366 |
+
## Prepare source
|
367 |
+
source_mask_idx = (source == self.mask_idx).nonzero().view(-1)
|
368 |
+
source[source_mask_idx] = self.uni_mask_idxs[:source_mask_idx.size(0)]
|
369 |
+
source = source[to_keep]
|
370 |
+
|
371 |
+
## Prepare target
|
372 |
+
to_keep[source_mask_idx] = 0
|
373 |
+
|
374 |
+
# source_mask_idx: from [a, b, c, ...] to [a, b + 1, c + 2, ...]
|
375 |
+
source_mask_idx = source_mask_idx + torch.arange(source_mask_idx.size(0))
|
376 |
+
# target: source_length + mask_length
|
377 |
+
target = source_ori.new_zeros(source_mask_idx.size(0) + source_ori.size(0))
|
378 |
+
# target: [0, 0, 0, X, 0, 0, Y, ....]
|
379 |
+
target[source_mask_idx] = self.uni_mask_idxs[:source_mask_idx.size(0)]
|
380 |
+
|
381 |
+
target_to_keep = to_keep.new_zeros(source_mask_idx.size(0) + source_ori.size(0))
|
382 |
+
|
383 |
+
# Copy original value to target and target_to_keep
|
384 |
+
target_to_keep[target == 0] = to_keep
|
385 |
+
target_to_keep[-1] = 0
|
386 |
+
target[target == 0] = source_ori
|
387 |
+
|
388 |
+
target = target[~target_to_keep]
|
389 |
+
|
390 |
+
if num_inserts > 0:
|
391 |
+
source = self.add_insertion_noise(source, num_inserts / source.size(0))
|
392 |
+
|
393 |
+
return source, target
|
394 |
+
|
395 |
+
def add_permuted_noise(self, tokens, p):
|
396 |
+
num_words = len(tokens)
|
397 |
+
num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
|
398 |
+
substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
|
399 |
+
tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
|
400 |
+
return tokens
|
401 |
+
|
402 |
+
def add_rolling_noise(self, tokens):
|
403 |
+
offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
|
404 |
+
tokens = torch.cat(
|
405 |
+
(tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
|
406 |
+
dim=0,
|
407 |
+
)
|
408 |
+
return tokens
|
409 |
+
|
410 |
+
def add_insertion_noise(self, tokens, p):
|
411 |
+
if p == 0.0:
|
412 |
+
return tokens
|
413 |
+
|
414 |
+
num_tokens = len(tokens)
|
415 |
+
n = int(math.ceil(num_tokens * p))
|
416 |
+
|
417 |
+
noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
|
418 |
+
noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
|
419 |
+
noise_mask[noise_indices] = 1
|
420 |
+
result = torch.LongTensor(n + len(tokens)).fill_(-1)
|
421 |
+
|
422 |
+
num_random = int(math.ceil(n * self.random_ratio))
|
423 |
+
result[noise_indices[num_random:]] = self.mask_idx
|
424 |
+
result[noise_indices[:num_random]] = torch.randint(
|
425 |
+
low=1, high=len(self.vocab), size=(num_random,)
|
426 |
+
)
|
427 |
+
|
428 |
+
result[~noise_mask] = tokens
|
429 |
+
|
430 |
+
assert (result >= 0).all()
|
431 |
+
return result
|
432 |
+
|
433 |
+
def collater(self, samples, pad_to_length=None):
|
434 |
+
"""Merge a list of samples to form a mini-batch.
|
435 |
+
Args:
|
436 |
+
samples (List[dict]): samples to collate
|
437 |
+
Returns:
|
438 |
+
dict: a mini-batch of data
|
439 |
+
"""
|
440 |
+
return collate(
|
441 |
+
samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
|
442 |
+
)
|
443 |
+
|
444 |
+
def num_tokens(self, index):
|
445 |
+
"""Return the number of tokens in a sample. This value is used to
|
446 |
+
enforce ``--max-tokens`` during batching."""
|
447 |
+
return self.sizes[index]
|
448 |
+
|
449 |
+
def size(self, index):
|
450 |
+
"""Return an example's size as a float or tuple. This value is used when
|
451 |
+
filtering a dataset with ``--max-positions``."""
|
452 |
+
return self.sizes[index]
|
453 |
+
|
454 |
+
def ordered_indices(self):
|
455 |
+
"""Return an ordered list of indices. Batches will be constructed based
|
456 |
+
on this order."""
|
457 |
+
if self.shuffle:
|
458 |
+
indices = np.random.permutation(len(self))
|
459 |
+
else:
|
460 |
+
indices = np.arange(len(self))
|
461 |
+
return indices[np.argsort(self.sizes[indices], kind="mergesort")]
|
462 |
+
|
463 |
+
def prefetch(self, indices):
|
464 |
+
self.src.prefetch(indices)
|
465 |
+
self.tgt.prefetch(indices)
|
466 |
+
|
467 |
+
@property
|
468 |
+
def supports_prefetch(self):
|
469 |
+
return (
|
470 |
+
hasattr(self.src, "supports_prefetch")
|
471 |
+
and self.src.supports_prefetch
|
472 |
+
and hasattr(self.tgt, "supports_prefetch")
|
473 |
+
and self.tgt.supports_prefetch
|
474 |
+
)
|
artst/data/text_to_speech_dataset.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
# Based on speecht5, fairseq and espnet code bases
|
5 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import itertools
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
from typing import Any, List, Optional
|
12 |
+
import mmap
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
import librosa
|
19 |
+
from fairseq.data.audio.speech_to_text_dataset import get_features_or_waveform
|
20 |
+
from fairseq.data import data_utils, Dictionary
|
21 |
+
from fairseq.data.fairseq_dataset import FairseqDataset
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
|
26 |
+
def _collate_frames(
|
27 |
+
frames: List[torch.Tensor], is_audio_input: bool = False
|
28 |
+
):
|
29 |
+
"""
|
30 |
+
Convert a list of 2D frames into a padded 3D tensor
|
31 |
+
Args:
|
32 |
+
frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
|
33 |
+
length of i-th frame and f_dim is static dimension of features
|
34 |
+
Returns:
|
35 |
+
3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
|
36 |
+
"""
|
37 |
+
max_len = max(frame.size(0) for frame in frames)
|
38 |
+
if is_audio_input:
|
39 |
+
out = frames[0].new_zeros((len(frames), max_len))
|
40 |
+
else:
|
41 |
+
out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
|
42 |
+
for i, v in enumerate(frames):
|
43 |
+
out[i, : v.size(0)] = v
|
44 |
+
return out
|
45 |
+
|
46 |
+
def load_audio(manifest_path, max_keep, min_keep):
|
47 |
+
n_long, n_short = 0, 0
|
48 |
+
names, inds, sizes, spk_embeds = [], [], [], []
|
49 |
+
with open(manifest_path) as f:
|
50 |
+
root = f.readline().strip()
|
51 |
+
for ind, line in enumerate(f):
|
52 |
+
items = line.strip().split("\t")
|
53 |
+
assert len(items) == 3, line
|
54 |
+
sz = int(items[1])
|
55 |
+
if min_keep is not None and sz < min_keep:
|
56 |
+
n_short += 1
|
57 |
+
elif max_keep is not None and sz > max_keep:
|
58 |
+
n_long += 1
|
59 |
+
else:
|
60 |
+
names.append(items[0])
|
61 |
+
spk_embeds.append(items[2])
|
62 |
+
inds.append(ind)
|
63 |
+
sizes.append(sz)
|
64 |
+
tot = ind + 1
|
65 |
+
logger.info(
|
66 |
+
(
|
67 |
+
f"max_keep={max_keep}, min_keep={min_keep}, "
|
68 |
+
f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
|
69 |
+
f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
|
70 |
+
)
|
71 |
+
)
|
72 |
+
return root, names, inds, tot, sizes, spk_embeds
|
73 |
+
|
74 |
+
|
75 |
+
def load_label(label_path, inds, tot):
|
76 |
+
with open(label_path) as f:
|
77 |
+
labels = [line.rstrip() for line in f]
|
78 |
+
assert (
|
79 |
+
len(labels) == tot
|
80 |
+
), f"number of labels does not match ({len(labels)} != {tot})"
|
81 |
+
labels = [labels[i] for i in inds]
|
82 |
+
return labels
|
83 |
+
|
84 |
+
|
85 |
+
def load_label_offset(label_path, inds, tot):
|
86 |
+
with open(label_path, encoding='utf-8') as f:
|
87 |
+
code_lengths = [len(line.encode("utf-8")) for line in f] #changed as in speech_to_text_dataset.py
|
88 |
+
assert (
|
89 |
+
len(code_lengths) == tot
|
90 |
+
), f"number of labels does not match ({len(code_lengths)} != {tot})"
|
91 |
+
offsets = list(itertools.accumulate([0] + code_lengths))
|
92 |
+
offsets = [(offsets[i], offsets[i + 1]) for i in inds]
|
93 |
+
return offsets
|
94 |
+
|
95 |
+
|
96 |
+
def logmelfilterbank(
|
97 |
+
audio,
|
98 |
+
sampling_rate,
|
99 |
+
fft_size=1024,
|
100 |
+
hop_size=256,
|
101 |
+
win_length=None,
|
102 |
+
window="hann",
|
103 |
+
num_mels=80,
|
104 |
+
fmin=80,
|
105 |
+
fmax=7600,
|
106 |
+
eps=1e-10,
|
107 |
+
):
|
108 |
+
"""Compute log-Mel filterbank feature.
|
109 |
+
(https://github.com/kan-bayashi/ParallelWaveGAN/blob/master/parallel_wavegan/bin/preprocess.py)
|
110 |
+
|
111 |
+
Args:
|
112 |
+
audio (ndarray): Audio signal (T,).
|
113 |
+
sampling_rate (int): Sampling rate.
|
114 |
+
fft_size (int): FFT size.
|
115 |
+
hop_size (int): Hop size.
|
116 |
+
win_length (int): Window length. If set to None, it will be the same as fft_size.
|
117 |
+
window (str): Window function type.
|
118 |
+
num_mels (int): Number of mel basis.
|
119 |
+
fmin (int): Minimum frequency in mel basis calculation.
|
120 |
+
fmax (int): Maximum frequency in mel basis calculation.
|
121 |
+
eps (float): Epsilon value to avoid inf in log calculation.
|
122 |
+
|
123 |
+
Returns:
|
124 |
+
ndarray: Log Mel filterbank feature (#frames, num_mels).
|
125 |
+
|
126 |
+
"""
|
127 |
+
# get amplitude spectrogram
|
128 |
+
x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size,
|
129 |
+
win_length=win_length, window=window, pad_mode="reflect")
|
130 |
+
spc = np.abs(x_stft).T # (#frames, #bins)
|
131 |
+
|
132 |
+
# get mel basis
|
133 |
+
fmin = 0 if fmin is None else fmin
|
134 |
+
fmax = sampling_rate / 2 if fmax is None else fmax
|
135 |
+
mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
136 |
+
|
137 |
+
return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
class TextToSpeechDataset(FairseqDataset):
|
142 |
+
def __init__(
|
143 |
+
self,
|
144 |
+
manifest_path: str,
|
145 |
+
sample_rate: float,
|
146 |
+
label_paths: List[str],
|
147 |
+
label_processors: Optional[List[Any]] = None,
|
148 |
+
max_keep_sample_size: Optional[int] = None,
|
149 |
+
min_keep_sample_size: Optional[int] = None,
|
150 |
+
shuffle: bool = True,
|
151 |
+
normalize: bool = False,
|
152 |
+
store_labels: bool = True,
|
153 |
+
src_dict: Optional[Dictionary] = None,
|
154 |
+
tokenizer = None,
|
155 |
+
reduction_factor: int = 1,
|
156 |
+
inference: bool = False,
|
157 |
+
):
|
158 |
+
|
159 |
+
self.audio_root, self.audio_names, inds, tot, self.wav_sizes, self.spk_embeds = load_audio(
|
160 |
+
manifest_path, max_keep_sample_size, min_keep_sample_size
|
161 |
+
)
|
162 |
+
self.inference = inference
|
163 |
+
|
164 |
+
self.sample_rate = sample_rate
|
165 |
+
self.shuffle = shuffle
|
166 |
+
self.src_dict = src_dict
|
167 |
+
self.tokenizer = tokenizer
|
168 |
+
|
169 |
+
self.num_labels = len(label_paths)
|
170 |
+
self.label_processors = label_processors
|
171 |
+
self.store_labels = store_labels
|
172 |
+
if store_labels:
|
173 |
+
self.label_list = [load_label(p, inds, tot) for p in label_paths]
|
174 |
+
else:
|
175 |
+
self.label_paths = label_paths
|
176 |
+
self.label_offsets_list = [
|
177 |
+
load_label_offset(p, inds, tot) for p in label_paths
|
178 |
+
]
|
179 |
+
assert label_processors is None or len(label_processors) == self.num_labels
|
180 |
+
|
181 |
+
self.normalize = normalize
|
182 |
+
self.reduction_factor = reduction_factor
|
183 |
+
logger.info(
|
184 |
+
f"reduction_factor={reduction_factor}, normalize={normalize}"
|
185 |
+
)
|
186 |
+
|
187 |
+
def get_audio(self, index):
|
188 |
+
import soundfile as sf
|
189 |
+
|
190 |
+
wav_path = os.path.join(self.audio_root, self.audio_names[index])
|
191 |
+
wav, cur_sample_rate = sf.read(wav_path)
|
192 |
+
wav = torch.from_numpy(wav).float()
|
193 |
+
fbank = logmelfilterbank(
|
194 |
+
wav.view(-1).cpu().numpy(), 16000
|
195 |
+
)
|
196 |
+
fbank = torch.from_numpy(fbank).float()
|
197 |
+
wav = self.postprocess(wav, cur_sample_rate)
|
198 |
+
return wav, fbank
|
199 |
+
|
200 |
+
def get_label(self, index, label_idx):
|
201 |
+
if self.store_labels:
|
202 |
+
label = self.label_list[label_idx][index]
|
203 |
+
else:
|
204 |
+
# with open(self.label_paths[label_idx]) as f:
|
205 |
+
# offset_s, offset_e = self.label_offsets_list[label_idx][index]
|
206 |
+
# f.seek(offset_s)
|
207 |
+
# label = f.read(offset_e - offset_s)
|
208 |
+
|
209 |
+
# Hawau:
|
210 |
+
# mmap method
|
211 |
+
with open(self.label_paths[label_idx], encoding='utf-8') as f:
|
212 |
+
with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
|
213 |
+
offset_s, offset_e = self.label_offsets_list[label_idx][index]
|
214 |
+
label = mm[offset_s:offset_e].decode("utf-8")
|
215 |
+
|
216 |
+
|
217 |
+
if self.tokenizer is not None:
|
218 |
+
label = self.tokenizer.encode(label)
|
219 |
+
|
220 |
+
if self.label_processors is not None:
|
221 |
+
label = self.label_processors[label_idx](label)
|
222 |
+
return label
|
223 |
+
|
224 |
+
def get_labels(self, index):
|
225 |
+
return [self.get_label(index, i) for i in range(self.num_labels)]
|
226 |
+
|
227 |
+
def __getitem__(self, index):
|
228 |
+
wav, fbank = self.get_audio(index)
|
229 |
+
labels = self.get_labels(index)
|
230 |
+
spkembs = get_features_or_waveform(
|
231 |
+
os.path.join(self.audio_root, self.spk_embeds[index])
|
232 |
+
)
|
233 |
+
spkembs = torch.from_numpy(spkembs).float()
|
234 |
+
|
235 |
+
return {"id": index, "source": labels, "target": fbank, "spkembs": spkembs, "audio_name": self.audio_names[index]}
|
236 |
+
|
237 |
+
|
238 |
+
def __len__(self):
|
239 |
+
return len(self.wav_sizes)
|
240 |
+
|
241 |
+
def collater(self, samples):
|
242 |
+
samples = [s for s in samples if s["source"] is not None]
|
243 |
+
if len(samples) == 0:
|
244 |
+
return {}
|
245 |
+
|
246 |
+
fbanks = [s["target"] for s in samples]
|
247 |
+
fbank_sizes = [len(s) for s in fbanks]
|
248 |
+
|
249 |
+
collated_fbanks = _collate_frames(fbanks)
|
250 |
+
collated_fbanks_size = torch.tensor(fbank_sizes, dtype=torch.long)
|
251 |
+
|
252 |
+
# thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim)
|
253 |
+
if self.reduction_factor > 1:
|
254 |
+
collated_fbanks_in = collated_fbanks[:, self.reduction_factor - 1 :: self.reduction_factor]
|
255 |
+
collated_fbanks_size_in = collated_fbanks_size.new([torch.div(olen, self.reduction_factor, rounding_mode='floor') for olen in collated_fbanks_size])
|
256 |
+
else:
|
257 |
+
collated_fbanks_in, collated_fbanks_size_in = collated_fbanks, collated_fbanks_size
|
258 |
+
|
259 |
+
prev_output_tokens = torch.cat(
|
260 |
+
[collated_fbanks_in.new_zeros((collated_fbanks_in.shape[0], 1, collated_fbanks_in.shape[2])), collated_fbanks_in[:, :-1]], dim=1
|
261 |
+
)
|
262 |
+
|
263 |
+
# make labels for stop prediction
|
264 |
+
labels = collated_fbanks.new_zeros(collated_fbanks.size(0), collated_fbanks.size(1))
|
265 |
+
for i, l in enumerate(fbank_sizes):
|
266 |
+
labels[i, l - 1 :] = 1.0
|
267 |
+
|
268 |
+
spkembs = _collate_frames([s["spkembs"] for s in samples], is_audio_input=True)
|
269 |
+
|
270 |
+
sources_by_label = [
|
271 |
+
[s["source"][i] for s in samples] for i in range(self.num_labels)
|
272 |
+
]
|
273 |
+
sources_list, lengths_list, ntokens_list = self.collater_label(sources_by_label)
|
274 |
+
|
275 |
+
net_input = {
|
276 |
+
"src_tokens": sources_list[0],
|
277 |
+
"src_lengths": lengths_list[0],
|
278 |
+
"prev_output_tokens": prev_output_tokens,
|
279 |
+
"tgt_lengths": collated_fbanks_size_in,
|
280 |
+
"spkembs": spkembs,
|
281 |
+
"task_name": "t2s",
|
282 |
+
}
|
283 |
+
batch = {
|
284 |
+
"id": torch.LongTensor([s["id"] for s in samples]),
|
285 |
+
"name": [s["audio_name"] for s in samples],
|
286 |
+
"net_input": net_input,
|
287 |
+
"labels": labels,
|
288 |
+
"dec_target": collated_fbanks,
|
289 |
+
"dec_target_lengths": collated_fbanks_size,
|
290 |
+
"src_lengths": lengths_list[0],
|
291 |
+
"task_name": "t2s",
|
292 |
+
"ntokens": ntokens_list[0],
|
293 |
+
"target": collated_fbanks,
|
294 |
+
}
|
295 |
+
|
296 |
+
return batch
|
297 |
+
|
298 |
+
def collater_seq_label(self, targets, pad):
|
299 |
+
lengths = torch.LongTensor([len(t) for t in targets])
|
300 |
+
ntokens = lengths.sum().item()
|
301 |
+
targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
|
302 |
+
return targets, lengths, ntokens
|
303 |
+
|
304 |
+
def collater_label(self, targets_by_label):
|
305 |
+
targets_list, lengths_list, ntokens_list = [], [], []
|
306 |
+
itr = zip(targets_by_label, [self.src_dict.pad()])
|
307 |
+
for targets, pad in itr:
|
308 |
+
targets, lengths, ntokens = self.collater_seq_label(targets, pad)
|
309 |
+
targets_list.append(targets)
|
310 |
+
lengths_list.append(lengths)
|
311 |
+
ntokens_list.append(ntokens)
|
312 |
+
return targets_list, lengths_list, ntokens_list
|
313 |
+
|
314 |
+
def num_tokens(self, index):
|
315 |
+
return self.size(index)
|
316 |
+
|
317 |
+
def size(self, index):
|
318 |
+
return self.wav_sizes[index]
|
319 |
+
|
320 |
+
@property
|
321 |
+
def sizes(self):
|
322 |
+
return np.array(self.wav_sizes)
|
323 |
+
|
324 |
+
def ordered_indices(self):
|
325 |
+
if self.shuffle:
|
326 |
+
order = [np.random.permutation(len(self))]
|
327 |
+
else:
|
328 |
+
order = [np.arange(len(self))]
|
329 |
+
|
330 |
+
order.append(self.wav_sizes)
|
331 |
+
return np.lexsort(order)[::-1]
|
332 |
+
|
333 |
+
def postprocess(self, wav, cur_sample_rate):
|
334 |
+
if wav.dim() == 2:
|
335 |
+
wav = wav.mean(-1)
|
336 |
+
assert wav.dim() == 1, wav.dim()
|
337 |
+
|
338 |
+
if cur_sample_rate != self.sample_rate:
|
339 |
+
raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")
|
340 |
+
|
341 |
+
if self.normalize:
|
342 |
+
with torch.no_grad():
|
343 |
+
wav = F.layer_norm(wav, wav.shape)
|
344 |
+
return wav
|
artst/models/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .artst import * # noqa
|
2 |
+
from .t5_transformer_lm import * # noqa
|
artst/models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (193 Bytes). View file
|
|
artst/models/__pycache__/artst.cpython-38.pyc
ADDED
Binary file (37.2 kB). View file
|
|
artst/models/__pycache__/speecht5.cpython-38.pyc
ADDED
Binary file (37 kB). View file
|
|
artst/models/__pycache__/t5_transformer_lm.cpython-38.pyc
ADDED
Binary file (733 Bytes). View file
|
|
artst/models/artst.py
ADDED
@@ -0,0 +1,1448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
|
3 |
+
# Github source: https://github.com/mbzuai-nlp/ArTST
|
4 |
+
# Based on speecht5, fairseq and espnet code bases
|
5 |
+
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
|
6 |
+
# --------------------------------------------------------
|
7 |
+
|
8 |
+
import logging
|
9 |
+
from ast import literal_eval
|
10 |
+
from typing import Dict, List, Optional, Tuple
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from fairseq import utils
|
15 |
+
from fairseq.models import (
|
16 |
+
FairseqEncoderDecoderModel,
|
17 |
+
FairseqIncrementalDecoder,
|
18 |
+
register_model,
|
19 |
+
register_model_architecture,
|
20 |
+
)
|
21 |
+
from .modules.text_encoder_prenet import TextEncoderPrenet
|
22 |
+
from .modules.text_decoder_prenet import TextDecoderPrenet
|
23 |
+
from .modules.text_decoder_postnet import TextDecoderPostnet
|
24 |
+
from .modules.speech_encoder_prenet import SpeechEncoderPrenet
|
25 |
+
from .modules.speech_encoder_postnet import SpeechEncoderPostnet
|
26 |
+
from .modules.speech_decoder_prenet import SpeechDecoderPrenet
|
27 |
+
from .modules.speech_decoder_postnet import SpeechDecoderPostnet
|
28 |
+
from .modules.speaker_decoder_postnet import SpeakerDecoderPostnet
|
29 |
+
from .modules.encoder import TransformerEncoder
|
30 |
+
from .modules.decoder import TransformerDecoder
|
31 |
+
from fairseq.modules.transformer_sentence_encoder import init_bert_params
|
32 |
+
from fairseq.models.transformer import Embedding
|
33 |
+
from fairseq.modules import (
|
34 |
+
GumbelVectorQuantizer,
|
35 |
+
)
|
36 |
+
from torch import Tensor
|
37 |
+
|
38 |
+
|
39 |
+
logger = logging.getLogger(__name__)
|
40 |
+
|
41 |
+
DEFAULT_MAX_TEXT_POSITIONS = 450
|
42 |
+
DEFAULT_MAX_SPEECH_POSITIONS = 4000
|
43 |
+
|
44 |
+
|
45 |
+
@register_model("artst_transformer")
|
46 |
+
class ArTSTTransformerModel(FairseqEncoderDecoderModel):
|
47 |
+
"""Adapted Transformer model (https://arxiv.org/abs/1706.03762) for
|
48 |
+
speech-to-text tasks. The Transformer encoder/decoder remains the same.
|
49 |
+
A trainable input subsampler is prepended to the Transformer encoder to
|
50 |
+
project inputs into the encoder dimension as well as downsample input
|
51 |
+
sequence for computational efficiency."""
|
52 |
+
|
53 |
+
def __init__(
|
54 |
+
self,
|
55 |
+
args,
|
56 |
+
encoder, decoder,
|
57 |
+
text_encoder_prenet, speech_encoder_prenet,
|
58 |
+
text_decoder_prenet, speech_decoder_prenet,
|
59 |
+
text_decoder_postnet, speech_decoder_postnet,
|
60 |
+
speaker_decoder_postnet, speech_encoder_postnet,
|
61 |
+
):
|
62 |
+
super().__init__(encoder, decoder)
|
63 |
+
|
64 |
+
self.encoder = encoder
|
65 |
+
self.decoder = decoder
|
66 |
+
|
67 |
+
self.text_encoder_prenet = text_encoder_prenet
|
68 |
+
self.speech_encoder_prenet = speech_encoder_prenet
|
69 |
+
|
70 |
+
self.text_decoder_prenet = text_decoder_prenet
|
71 |
+
self.speech_decoder_prenet = speech_decoder_prenet
|
72 |
+
|
73 |
+
self.text_decoder_postnet = text_decoder_postnet
|
74 |
+
self.speech_decoder_postnet = speech_decoder_postnet
|
75 |
+
self.speaker_decoder_postnet = speaker_decoder_postnet
|
76 |
+
|
77 |
+
self.hubert_layer = speech_encoder_postnet
|
78 |
+
|
79 |
+
self.reduction_factor = args.reduction_factor
|
80 |
+
self.spk_embed_dim = args.spk_embed_dim
|
81 |
+
|
82 |
+
# define projection layer
|
83 |
+
self.spk_embed_integration_type = args.spk_embed_integration_type
|
84 |
+
if self.spk_embed_dim is not None and self.spk_embed_integration_type != 'pre':
|
85 |
+
if self.spk_embed_integration_type == "add":
|
86 |
+
self.projection = torch.nn.Linear(self.spk_embed_dim, args.decoder_embed_dim)
|
87 |
+
else:
|
88 |
+
self.projection = torch.nn.Linear(
|
89 |
+
args.decoder_embed_dim + self.spk_embed_dim, args.decoder_embed_dim
|
90 |
+
)
|
91 |
+
|
92 |
+
# Hawau: here we can add language embedding integration
|
93 |
+
|
94 |
+
self.use_codebook = args.use_codebook
|
95 |
+
self.codebook_prob = getattr(args, "codebook_prob", 0.5) # args.codebook_prob
|
96 |
+
if self.use_codebook:
|
97 |
+
vq_dim = args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim
|
98 |
+
self.quantizer = GumbelVectorQuantizer(
|
99 |
+
dim=args.encoder_embed_dim,
|
100 |
+
num_vars=args.latent_vars,
|
101 |
+
temp=args.latent_temp,
|
102 |
+
groups=args.latent_groups,
|
103 |
+
combine_groups=False,
|
104 |
+
vq_dim=vq_dim,
|
105 |
+
time_first=True,
|
106 |
+
weight_proj_depth=args.quantizer_depth,
|
107 |
+
weight_proj_factor=args.quantizer_factor,
|
108 |
+
)
|
109 |
+
|
110 |
+
self.num_updates = 0
|
111 |
+
|
112 |
+
# # Follow BERT's random weight initialization (for BART)
|
113 |
+
if args.bert_init:
|
114 |
+
self.apply(init_bert_params)
|
115 |
+
self.args = args
|
116 |
+
self.prune_modules(args.modules_filter)
|
117 |
+
|
118 |
+
@staticmethod
|
119 |
+
def add_args(parser):
|
120 |
+
"""Add model-specific arguments to the parser."""
|
121 |
+
# Transformer
|
122 |
+
parser.add_argument(
|
123 |
+
"--activation-fn",
|
124 |
+
type=str,
|
125 |
+
choices=utils.get_available_activation_fns(),
|
126 |
+
help="activation function to use",
|
127 |
+
)
|
128 |
+
parser.add_argument(
|
129 |
+
"--dropout", type=float, metavar="D", help="dropout probability"
|
130 |
+
)
|
131 |
+
parser.add_argument(
|
132 |
+
"--attention-dropout",
|
133 |
+
type=float,
|
134 |
+
metavar="D",
|
135 |
+
help="dropout probability for attention weights",
|
136 |
+
)
|
137 |
+
parser.add_argument(
|
138 |
+
"--activation-dropout",
|
139 |
+
"--relu-dropout",
|
140 |
+
type=float,
|
141 |
+
metavar="D",
|
142 |
+
help="dropout probability after activation in FFN.",
|
143 |
+
)
|
144 |
+
parser.add_argument(
|
145 |
+
"--encoder-embed-dim",
|
146 |
+
type=int,
|
147 |
+
metavar="N",
|
148 |
+
help="encoder embedding dimension",
|
149 |
+
)
|
150 |
+
parser.add_argument(
|
151 |
+
"--encoder-ffn-embed-dim",
|
152 |
+
type=int,
|
153 |
+
metavar="N",
|
154 |
+
help="encoder embedding dimension for FFN",
|
155 |
+
)
|
156 |
+
parser.add_argument(
|
157 |
+
"--encoder-layers", type=int, metavar="N", help="num encoder layers"
|
158 |
+
)
|
159 |
+
parser.add_argument(
|
160 |
+
"--encoder-attention-heads",
|
161 |
+
type=int,
|
162 |
+
metavar="N",
|
163 |
+
help="num encoder attention heads",
|
164 |
+
)
|
165 |
+
parser.add_argument(
|
166 |
+
"--encoder-normalize-before",
|
167 |
+
action="store_true",
|
168 |
+
help="apply layernorm before each encoder block",
|
169 |
+
)
|
170 |
+
parser.add_argument(
|
171 |
+
"--decoder-normalize-before",
|
172 |
+
action="store_true",
|
173 |
+
help="apply layernorm before each decoder block",
|
174 |
+
)
|
175 |
+
parser.add_argument(
|
176 |
+
"--decoder-embed-dim",
|
177 |
+
type=int,
|
178 |
+
metavar="N",
|
179 |
+
help="decoder embedding dimension",
|
180 |
+
)
|
181 |
+
parser.add_argument(
|
182 |
+
"--decoder-ffn-embed-dim",
|
183 |
+
type=int,
|
184 |
+
metavar="N",
|
185 |
+
help="decoder embedding dimension for FFN",
|
186 |
+
)
|
187 |
+
parser.add_argument(
|
188 |
+
"--decoder-layers", type=int, metavar="N", help="num decoder layers"
|
189 |
+
)
|
190 |
+
parser.add_argument(
|
191 |
+
"--decoder-attention-heads",
|
192 |
+
type=int,
|
193 |
+
metavar="N",
|
194 |
+
help="num decoder attention heads",
|
195 |
+
)
|
196 |
+
parser.add_argument(
|
197 |
+
"--reduction-factor",
|
198 |
+
type=int,
|
199 |
+
help="reduction factor for decoder",
|
200 |
+
)
|
201 |
+
parser.add_argument(
|
202 |
+
"--spk-embed-dim",
|
203 |
+
type=int,
|
204 |
+
help="speaker embedding dimension",
|
205 |
+
)
|
206 |
+
parser.add_argument(
|
207 |
+
"--layernorm-embedding",
|
208 |
+
action="store_true",
|
209 |
+
help="add layernorm to embedding",
|
210 |
+
)
|
211 |
+
parser.add_argument(
|
212 |
+
"--load-pretrained-encoder-from",
|
213 |
+
type=str,
|
214 |
+
metavar="STR",
|
215 |
+
help="model to take encoder weights from (for initialization)",
|
216 |
+
)
|
217 |
+
parser.add_argument(
|
218 |
+
'--freeze-encoder-updates',
|
219 |
+
type=int,
|
220 |
+
help='number of steps to freeze encoder before finetune'
|
221 |
+
)
|
222 |
+
parser.add_argument(
|
223 |
+
'--freeze-decoder-updates',
|
224 |
+
type=int,
|
225 |
+
help='number of steps to freeze decoder before finetune'
|
226 |
+
)
|
227 |
+
parser.add_argument(
|
228 |
+
'--no-freeze-encoder-layer',
|
229 |
+
type=str,
|
230 |
+
help='which encoder layer not freeze during finetune'
|
231 |
+
)
|
232 |
+
parser.add_argument(
|
233 |
+
"--share-input-output-embed",
|
234 |
+
action="store_true",
|
235 |
+
help="share decoder input and output embeddings",
|
236 |
+
)
|
237 |
+
parser.add_argument(
|
238 |
+
"--share-ctc-embed",
|
239 |
+
action="store_true",
|
240 |
+
help="share ctc embed and decoder embed",
|
241 |
+
)
|
242 |
+
parser.add_argument(
|
243 |
+
"--encoder-sliding-window-attn",
|
244 |
+
default=None,
|
245 |
+
type=int,
|
246 |
+
help="If not None but a even number, set sliding window attention to encoder's attn_mask, e.g., 4, 10, and 20",
|
247 |
+
)
|
248 |
+
|
249 |
+
# Convolutional subsampler
|
250 |
+
parser.add_argument(
|
251 |
+
"--encoder-speech-prenet",
|
252 |
+
default="conv",
|
253 |
+
type=str,
|
254 |
+
choices=["conv", "linear"],
|
255 |
+
help="The type of encoder speech prenet, e.g., conv or linear."
|
256 |
+
)
|
257 |
+
parser.add_argument(
|
258 |
+
"--conv-kernel-sizes",
|
259 |
+
default="5,5",
|
260 |
+
type=str,
|
261 |
+
help="The layer of convolution of encoder speech prenet."
|
262 |
+
)
|
263 |
+
parser.add_argument(
|
264 |
+
"--conv-channels",
|
265 |
+
default=1024,
|
266 |
+
type=int,
|
267 |
+
help="The channels of encoder speech prenet."
|
268 |
+
)
|
269 |
+
parser.add_argument(
|
270 |
+
"--subsample-stride",
|
271 |
+
default="2,2",
|
272 |
+
type=str,
|
273 |
+
help="The subsample stride for conv1dsubsample."
|
274 |
+
)
|
275 |
+
parser.add_argument(
|
276 |
+
"--spk-embed-integration-type",
|
277 |
+
type=str,
|
278 |
+
choices=["pre", "add"],
|
279 |
+
help="speaker embedding integration type"
|
280 |
+
)
|
281 |
+
parser.add_argument(
|
282 |
+
"--dprenet-dropout-rate",
|
283 |
+
default=0.5,
|
284 |
+
type=float,
|
285 |
+
help="The dropout rate of decoder speech prenet."
|
286 |
+
)
|
287 |
+
|
288 |
+
## SE
|
289 |
+
parser.add_argument(
|
290 |
+
"--se-predict",
|
291 |
+
default=None,
|
292 |
+
choices=["masking", "target", "delta"],
|
293 |
+
help="If set, source speech inputs decoder to predict the masking/target/delta of corresponding inputs."
|
294 |
+
+ "masking is [0, 1], target is predicted output, delta is difference between inputs and outputs",
|
295 |
+
)
|
296 |
+
parser.add_argument(
|
297 |
+
"--se-decoder-input",
|
298 |
+
type=str,
|
299 |
+
default="previous_target",
|
300 |
+
choices=["previous_target", "source"],
|
301 |
+
)
|
302 |
+
|
303 |
+
## SID
|
304 |
+
parser.add_argument(
|
305 |
+
"--modules-filter",
|
306 |
+
default=None,
|
307 |
+
type=str,
|
308 |
+
help="Remove unused modules for, e.g., SID.",
|
309 |
+
)
|
310 |
+
parser.add_argument(
|
311 |
+
"--sid-pad-prenet",
|
312 |
+
action="store_true",
|
313 |
+
help="If set, the size of text dictionary is as small as for <pad> token.",
|
314 |
+
)
|
315 |
+
parser.add_argument(
|
316 |
+
"--encoder-attn-branch",
|
317 |
+
type=str,
|
318 |
+
default="identity,full",
|
319 |
+
help="encoder attention branch sliding window, e.g., 'identity,0,2,4,full'",
|
320 |
+
)
|
321 |
+
parser.add_argument(
|
322 |
+
"--encoder-block-branch",
|
323 |
+
type=str,
|
324 |
+
help="average the output of encoder, e.g., '4,5,6'",
|
325 |
+
)
|
326 |
+
parser.add_argument(
|
327 |
+
"--sid-encoder-cls",
|
328 |
+
default=None,
|
329 |
+
choices=["encoder"],
|
330 |
+
help="If set, add cls vector to the encoder input, e.g., constant vector.",
|
331 |
+
)
|
332 |
+
parser.add_argument(
|
333 |
+
"--sid-shuffle-encoder-input",
|
334 |
+
action="store_true",
|
335 |
+
help="If set, shuffle encoder input in time.",
|
336 |
+
)
|
337 |
+
parser.add_argument(
|
338 |
+
"--sid-decoder-speaker",
|
339 |
+
action="store_true",
|
340 |
+
help="If set, apply speaker decoder as transformer decoder.",
|
341 |
+
)
|
342 |
+
parser.add_argument(
|
343 |
+
"--sid-decoder-attn-dim",
|
344 |
+
default=128,
|
345 |
+
type=int,
|
346 |
+
help="Attention dimension in attensive statistics pooling of speaker decoder.",
|
347 |
+
)
|
348 |
+
parser.add_argument(
|
349 |
+
"--sid-t5-postnet",
|
350 |
+
action="store_true",
|
351 |
+
help="If set, apply TextDecoderPostnet as speaker classification.",
|
352 |
+
)
|
353 |
+
parser.add_argument(
|
354 |
+
"--sid-embed-dim",
|
355 |
+
default=128,
|
356 |
+
type=int,
|
357 |
+
help="Embedding dimension in speaker postnet for speaker identification if embed postnet.",
|
358 |
+
)
|
359 |
+
parser.add_argument(
|
360 |
+
"--sid-pooling-layer",
|
361 |
+
default="decoder",
|
362 |
+
type=str,
|
363 |
+
choices=["decoder-las", "decoder", "encoder", "encoder-cls", "encoder-speaker"],
|
364 |
+
help="The output of decoder or encoder uses as SID pooling layer over temporal dimension.",
|
365 |
+
)
|
366 |
+
parser.add_argument(
|
367 |
+
"--sid-no-pooling-bn",
|
368 |
+
action="store_true",
|
369 |
+
help="If set, not attention batchnorm.",
|
370 |
+
)
|
371 |
+
parser.add_argument(
|
372 |
+
"--sid-no-embed-postnet",
|
373 |
+
action="store_true",
|
374 |
+
help="If set, no layer between decoder output and classification layer.",
|
375 |
+
)
|
376 |
+
parser.add_argument(
|
377 |
+
"--sid-normalize-postnet",
|
378 |
+
action="store_true",
|
379 |
+
help="If set, normalize input and weight in postnet/classifier.",
|
380 |
+
)
|
381 |
+
parser.add_argument(
|
382 |
+
"--sid-softmax-type",
|
383 |
+
default="softmax",
|
384 |
+
choices=["softmax", "amsoftmax", "aamsoftmax"],
|
385 |
+
help="If using amsoftmax or aamsoftmax, the target should be given.",
|
386 |
+
)
|
387 |
+
parser.add_argument(
|
388 |
+
"--softmax-scale",
|
389 |
+
default=1.0,
|
390 |
+
type=float,
|
391 |
+
help="Scale for AMSoftmax or AAMSoftmax.",
|
392 |
+
)
|
393 |
+
parser.add_argument(
|
394 |
+
"--softmax-margin",
|
395 |
+
default=0.0,
|
396 |
+
type=float,
|
397 |
+
help="Margin for AMSoftmax or AAMSoftmax.",
|
398 |
+
)
|
399 |
+
parser.add_argument(
|
400 |
+
"--softmax-easy-margin",
|
401 |
+
action="store_true",
|
402 |
+
help="Enable easy margin for AAMSoftmax.",
|
403 |
+
)
|
404 |
+
parser.add_argument(
|
405 |
+
"--encoder-layerdrop",
|
406 |
+
type=float,
|
407 |
+
metavar="D",
|
408 |
+
help="LayerDrop probability for encoder",
|
409 |
+
)
|
410 |
+
parser.add_argument(
|
411 |
+
"--decoder-layerdrop",
|
412 |
+
type=float,
|
413 |
+
metavar="D",
|
414 |
+
help="LayerDrop probability for decoder",
|
415 |
+
)
|
416 |
+
|
417 |
+
## Hubert
|
418 |
+
parser.add_argument(
|
419 |
+
'--feature-grad-mult',
|
420 |
+
type=float,
|
421 |
+
help='multiply feature extractor var grads by this'
|
422 |
+
)
|
423 |
+
parser.add_argument(
|
424 |
+
'--logit-temp',
|
425 |
+
type=float,
|
426 |
+
help='temperature to divide logits by'
|
427 |
+
)
|
428 |
+
parser.add_argument(
|
429 |
+
'--final-dim',
|
430 |
+
type=int,
|
431 |
+
help="project final representations and targets to this many "
|
432 |
+
"dimensions. set to encoder_embed_dim is <= 0"
|
433 |
+
)
|
434 |
+
|
435 |
+
# mask
|
436 |
+
parser.add_argument(
|
437 |
+
'--hubert-mask-length',
|
438 |
+
type=int,
|
439 |
+
help='mask length'
|
440 |
+
)
|
441 |
+
parser.add_argument(
|
442 |
+
'--mask-prob',
|
443 |
+
type=float,
|
444 |
+
help='probability of replacing a token with mask'
|
445 |
+
)
|
446 |
+
parser.add_argument(
|
447 |
+
"--mask-selection",
|
448 |
+
choices=["static", "uniform", "normal", "poisson"],
|
449 |
+
help="how to choose mask length",
|
450 |
+
)
|
451 |
+
parser.add_argument(
|
452 |
+
'--mask-other',
|
453 |
+
type=float,
|
454 |
+
help="secondary mask argument "
|
455 |
+
"(used for more complex distributions), "
|
456 |
+
"see help in compute_mask_indices"
|
457 |
+
)
|
458 |
+
parser.add_argument(
|
459 |
+
'--mask-min-space',
|
460 |
+
type=int,
|
461 |
+
help='min space between spans (if no overlap is enabled)'
|
462 |
+
)
|
463 |
+
|
464 |
+
# channel masking
|
465 |
+
parser.add_argument(
|
466 |
+
'--mask-channel-length',
|
467 |
+
type=int,
|
468 |
+
help='length of the mask for features (channels)'
|
469 |
+
)
|
470 |
+
parser.add_argument(
|
471 |
+
'--mask-channel-prob',
|
472 |
+
type=float,
|
473 |
+
help="probability of replacing a feature with 0"
|
474 |
+
)
|
475 |
+
parser.add_argument(
|
476 |
+
"--mask-channel-selection",
|
477 |
+
choices=["static", "uniform", "normal", "poisson"],
|
478 |
+
help="how to choose mask length for channel masking",
|
479 |
+
)
|
480 |
+
parser.add_argument(
|
481 |
+
'--mask-channel-other',
|
482 |
+
type=float,
|
483 |
+
help="secondary mask argument "
|
484 |
+
"(used for more complex distributions), "
|
485 |
+
"see help in compute_mask_indices"
|
486 |
+
)
|
487 |
+
parser.add_argument(
|
488 |
+
'--mask-channel-min-space',
|
489 |
+
type=int,
|
490 |
+
help='min space between spans (if no overlap is enabled)'
|
491 |
+
)
|
492 |
+
|
493 |
+
# abs positional embeddings
|
494 |
+
parser.add_argument(
|
495 |
+
'--conv-pos',
|
496 |
+
type=int,
|
497 |
+
help='number of filters for convolutional positional embeddings'
|
498 |
+
)
|
499 |
+
parser.add_argument(
|
500 |
+
'--conv-pos-groups',
|
501 |
+
type=int,
|
502 |
+
help='number of groups for convolutional positional embedding'
|
503 |
+
)
|
504 |
+
|
505 |
+
# codebook related
|
506 |
+
parser.add_argument(
|
507 |
+
"--use-codebook",
|
508 |
+
action="store_true",
|
509 |
+
help="whether to use codebook",
|
510 |
+
)
|
511 |
+
parser.add_argument(
|
512 |
+
"--codebook-prob",
|
513 |
+
type=float,
|
514 |
+
help="probability to use codebook",
|
515 |
+
)
|
516 |
+
parser.add_argument(
|
517 |
+
"--latent-vars",
|
518 |
+
type=int,
|
519 |
+
help="number of latent variables V in each group of the codebook",
|
520 |
+
)
|
521 |
+
parser.add_argument(
|
522 |
+
"--latent-groups",
|
523 |
+
type=int,
|
524 |
+
help="number of groups G of latent variables in the codebook",
|
525 |
+
)
|
526 |
+
parser.add_argument(
|
527 |
+
"--latent-dim",
|
528 |
+
type=int,
|
529 |
+
help="if > 0, uses this dimensionality for latent variables. "
|
530 |
+
"otherwise uses final_dim / latent_groups",
|
531 |
+
)
|
532 |
+
parser.add_argument(
|
533 |
+
"--latent-temp",
|
534 |
+
type=literal_eval,
|
535 |
+
help="temperature for latent variable sampling. "
|
536 |
+
"can be tuple of 3 values (start, end, decay)",
|
537 |
+
)
|
538 |
+
parser.add_argument(
|
539 |
+
"--quantizer-depth",
|
540 |
+
type=int,
|
541 |
+
help="number of quantizer layers",
|
542 |
+
)
|
543 |
+
parser.add_argument(
|
544 |
+
"--quantizer-factor",
|
545 |
+
type=int,
|
546 |
+
help="number of quantizer layers",
|
547 |
+
)
|
548 |
+
parser.add_argument(
|
549 |
+
"--get-code-distribution",
|
550 |
+
action='store_true',
|
551 |
+
help="whether to get the code distribution (for test)",
|
552 |
+
)
|
553 |
+
|
554 |
+
# relative pos enc
|
555 |
+
parser.add_argument(
|
556 |
+
"--relative-position-embedding",
|
557 |
+
action='store_true',
|
558 |
+
help="whether to use relative position embedding",
|
559 |
+
)
|
560 |
+
parser.add_argument(
|
561 |
+
"--num-buckets",
|
562 |
+
type=int,
|
563 |
+
default=320,
|
564 |
+
help="num of buckets for relative position embedding",
|
565 |
+
)
|
566 |
+
parser.add_argument(
|
567 |
+
"--max-distance",
|
568 |
+
type=int,
|
569 |
+
default=1280,
|
570 |
+
help="max distance for relative position embedding",
|
571 |
+
)
|
572 |
+
parser.add_argument(
|
573 |
+
"--encoder-max-relative-position",
|
574 |
+
type=int,
|
575 |
+
help="max distance for relative position embedding in encoder",
|
576 |
+
)
|
577 |
+
parser.add_argument(
|
578 |
+
"--decoder-max-relative-position",
|
579 |
+
type=int,
|
580 |
+
help="max distance for relative position embedding in decoder",
|
581 |
+
)
|
582 |
+
|
583 |
+
# hubert feature extractor
|
584 |
+
parser.add_argument(
|
585 |
+
"--conv-feature-layers",
|
586 |
+
type=str,
|
587 |
+
help= "string describing convolutional feature extraction "
|
588 |
+
"layers in form of a python list that contains "
|
589 |
+
"[(dim, kernel_size, stride), ...]",
|
590 |
+
)
|
591 |
+
parser.add_argument(
|
592 |
+
"--conv-bias",
|
593 |
+
action='store_true',
|
594 |
+
help="include bias in conv encoder",
|
595 |
+
)
|
596 |
+
parser.add_argument(
|
597 |
+
"--extractor-mode",
|
598 |
+
choices=["default", "layer_norm"],
|
599 |
+
help="mode for feature extractor. default has a single group "
|
600 |
+
"norm with d groups in the first conv block, whereas layer_norm "
|
601 |
+
"has layer norms in every block (meant to use with normalize=True)"
|
602 |
+
)
|
603 |
+
|
604 |
+
# others
|
605 |
+
parser.add_argument(
|
606 |
+
"--bert-init",
|
607 |
+
action='store_true',
|
608 |
+
help="initilize as bert",
|
609 |
+
)
|
610 |
+
parser.add_argument(
|
611 |
+
"--unb-enc-layer",
|
612 |
+
type=int,
|
613 |
+
default=-1,
|
614 |
+
help="which layer's output is used as the input of decoder",
|
615 |
+
)
|
616 |
+
|
617 |
+
# Encoder, Decoder
|
618 |
+
@classmethod
|
619 |
+
def build_encoder(cls, args, dictionary=None, embed_tokens=None):
|
620 |
+
return TransformerEncoder(args, dictionary, embed_tokens)
|
621 |
+
|
622 |
+
@classmethod
|
623 |
+
def build_decoder(cls, args):
|
624 |
+
return TransformerDecoder(args)
|
625 |
+
|
626 |
+
# Encoder Prenet
|
627 |
+
@classmethod
|
628 |
+
def build_text_encoder_prenet(cls, embed_tokens, args):
|
629 |
+
return TextEncoderPrenet(embed_tokens, args)
|
630 |
+
|
631 |
+
@classmethod
|
632 |
+
def build_speech_encoder_prenet(cls, args):
|
633 |
+
return SpeechEncoderPrenet(args)
|
634 |
+
|
635 |
+
# Decoder Prenet
|
636 |
+
@classmethod
|
637 |
+
def build_text_decoder_prenet(cls, embed_tokens, args):
|
638 |
+
return TextDecoderPrenet(embed_tokens, args)
|
639 |
+
|
640 |
+
@classmethod
|
641 |
+
def build_speech_decoder_prenet(cls, odim, args):
|
642 |
+
return SpeechDecoderPrenet(odim, args)
|
643 |
+
|
644 |
+
# Decoder Postnet
|
645 |
+
@classmethod
|
646 |
+
def build_text_decoder_postnet(cls, embed_tokens, dictionary, args):
|
647 |
+
return TextDecoderPostnet(embed_tokens, dictionary, args)
|
648 |
+
|
649 |
+
@classmethod
|
650 |
+
def build_speaker_decoder_postnet(cls, embed_dim, class_num, args):
|
651 |
+
return SpeakerDecoderPostnet(embed_dim, class_num, args)
|
652 |
+
|
653 |
+
@classmethod
|
654 |
+
def build_speech_decoder_postnet(cls, odim, args):
|
655 |
+
return SpeechDecoderPostnet(odim, args)
|
656 |
+
|
657 |
+
@classmethod
|
658 |
+
def build_speech_encoder_postnet(cls, dictionaries, args):
|
659 |
+
return SpeechEncoderPostnet(dictionaries, args)
|
660 |
+
|
661 |
+
@classmethod
|
662 |
+
def build_model(cls, args, task):
|
663 |
+
"""Build a new model instance."""
|
664 |
+
|
665 |
+
# make sure all arguments are present in older models
|
666 |
+
base_architecture(args)
|
667 |
+
|
668 |
+
def build_embedding(dictionary, embed_dim, max_num_embeddings=None):
|
669 |
+
num_embeddings = len(dictionary)
|
670 |
+
if max_num_embeddings is not None and isinstance(max_num_embeddings, int):
|
671 |
+
num_embeddings = min(num_embeddings, max_num_embeddings)
|
672 |
+
padding_idx = dictionary.pad()
|
673 |
+
return Embedding(num_embeddings, embed_dim, padding_idx)
|
674 |
+
|
675 |
+
if hasattr(args, "sid_pad_prenet") and args.sid_pad_prenet:
|
676 |
+
max_num_embeddings = 3 # <pad> at index 2
|
677 |
+
else:
|
678 |
+
max_num_embeddings = None
|
679 |
+
|
680 |
+
text_decoder_embed_tokens = build_embedding(
|
681 |
+
task.dicts["text"], args.decoder_embed_dim, max_num_embeddings
|
682 |
+
)
|
683 |
+
|
684 |
+
if args.share_input_output_embed:
|
685 |
+
text_encoder_embed_tokens = text_decoder_embed_tokens
|
686 |
+
else:
|
687 |
+
text_encoder_embed_tokens = build_embedding(
|
688 |
+
task.dicts["text"], args.encoder_embed_dim
|
689 |
+
)
|
690 |
+
|
691 |
+
speech_odim = args.speech_odim
|
692 |
+
if "text" in task.dicts:
|
693 |
+
encoder = cls.build_encoder(args, task.dicts["text"], text_encoder_embed_tokens)
|
694 |
+
else:
|
695 |
+
encoder = cls.build_encoder(args)
|
696 |
+
decoder = cls.build_decoder(args)
|
697 |
+
|
698 |
+
text_encoder_prenet = cls.build_text_encoder_prenet(text_encoder_embed_tokens, args)
|
699 |
+
speech_encoder_prenet = cls.build_speech_encoder_prenet(args)
|
700 |
+
|
701 |
+
text_decoder_prenet = cls.build_text_decoder_prenet(text_decoder_embed_tokens, args)
|
702 |
+
if getattr(args, "sid_pooling_layer", None) == "decoder-las":
|
703 |
+
speech_decoder_prenet = cls.build_speech_encoder_prenet(args)
|
704 |
+
else:
|
705 |
+
speech_decoder_prenet = cls.build_speech_decoder_prenet(speech_odim, args)
|
706 |
+
|
707 |
+
text_decoder_postnet = cls.build_text_decoder_postnet(text_decoder_embed_tokens, task.dicts['text'], args)
|
708 |
+
speech_decoder_postnet = cls.build_speech_decoder_postnet(speech_odim, args)
|
709 |
+
|
710 |
+
if getattr(args, "sid_t5_postnet", False):
|
711 |
+
speaker_decoder_postnet = None
|
712 |
+
else:
|
713 |
+
if task.t5_task == "s2c":
|
714 |
+
speaker_decoder_postnet = cls.build_speaker_decoder_postnet(args.sid_embed_dim, len(task.dicts['text']), args)
|
715 |
+
else:
|
716 |
+
speaker_decoder_postnet = None
|
717 |
+
|
718 |
+
if "hubert" in task.dicts:
|
719 |
+
speech_encoder_postnet = cls.build_speech_encoder_postnet(task.dicts['hubert'], args)
|
720 |
+
else:
|
721 |
+
speech_encoder_postnet = None
|
722 |
+
|
723 |
+
return cls(
|
724 |
+
args,
|
725 |
+
encoder, decoder,
|
726 |
+
text_encoder_prenet, speech_encoder_prenet,
|
727 |
+
text_decoder_prenet, speech_decoder_prenet,
|
728 |
+
text_decoder_postnet, speech_decoder_postnet,
|
729 |
+
speaker_decoder_postnet, speech_encoder_postnet,
|
730 |
+
)
|
731 |
+
|
732 |
+
def get_normalized_probs(
|
733 |
+
self,
|
734 |
+
net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
|
735 |
+
log_probs: bool,
|
736 |
+
sample: Optional[Dict[str, Tensor]] = None,
|
737 |
+
):
|
738 |
+
# net_output['encoder_out'] is a (B, T, D) tensor
|
739 |
+
lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
|
740 |
+
lprobs.batch_first = True
|
741 |
+
return lprobs
|
742 |
+
|
743 |
+
def get_normalized_probs_for_ctc(self, net_output, log_probs):
|
744 |
+
"""Get normalized probabilities (or log probs) from a net's output."""
|
745 |
+
|
746 |
+
logits = net_output["encoder_out_for_ctc"][0]
|
747 |
+
if log_probs:
|
748 |
+
return utils.log_softmax(logits.float(), dim=-1)
|
749 |
+
else:
|
750 |
+
return utils.softmax(logits.float(), dim=-1)
|
751 |
+
|
752 |
+
def get_logits(self, net_output, is_masked=True):
|
753 |
+
if is_masked:
|
754 |
+
logits_list = net_output["logit_m_list"]
|
755 |
+
else:
|
756 |
+
logits_list = net_output["logit_u_list"]
|
757 |
+
logits_list = [x.float() for x in logits_list if x is not None]
|
758 |
+
return logits_list
|
759 |
+
|
760 |
+
def get_targets(self, sample, net_output, is_masked=True):
|
761 |
+
if "logit_m_list" in net_output:
|
762 |
+
logits_list = self.get_logits(net_output, is_masked)
|
763 |
+
targets_list = [
|
764 |
+
x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list
|
765 |
+
]
|
766 |
+
return targets_list
|
767 |
+
else:
|
768 |
+
return sample["target"]
|
769 |
+
|
770 |
+
def get_extra_losses(self, net_output):
|
771 |
+
extra_losses = []
|
772 |
+
names = []
|
773 |
+
|
774 |
+
if "features_pen" in net_output:
|
775 |
+
extra_losses.append(net_output["features_pen"])
|
776 |
+
names.append("features_pen")
|
777 |
+
|
778 |
+
if "prob_perplexity" in net_output:
|
779 |
+
extra_losses.append(
|
780 |
+
(net_output["num_vars"] - net_output["prob_perplexity"])
|
781 |
+
/ net_output["num_vars"]
|
782 |
+
)
|
783 |
+
names.append("prob_perplexity")
|
784 |
+
|
785 |
+
return extra_losses, names
|
786 |
+
|
787 |
+
def forward(self, source=None, src_tokens=None, src_lengths=None, prev_output_tokens=None, tgt_lengths=None, spkembs=None, target_list=None, task_name=None, padding_mask=None, only_hubert=False, only_ctc=False, feature_only=False, tgt_enc_layer=None, mask=True):
|
788 |
+
"""
|
789 |
+
The forward method inherited from the base class has a **kwargs
|
790 |
+
argument in its signature, which is not supported by TorchScript. This
|
791 |
+
method overrides the forward method definition without **kwargs.
|
792 |
+
"""
|
793 |
+
assert source is not None or src_tokens is not None
|
794 |
+
# padding_mask is not None only when the input is a waveform
|
795 |
+
if source is None and padding_mask is None and not feature_only:
|
796 |
+
input_type = 'text'
|
797 |
+
else:
|
798 |
+
input_type = 'speech'
|
799 |
+
|
800 |
+
if prev_output_tokens is not None and len(prev_output_tokens.size()) == 2:
|
801 |
+
output_type = 'text'
|
802 |
+
codebook_out = {}
|
803 |
+
else:
|
804 |
+
output_type = 'speech'
|
805 |
+
|
806 |
+
if task_name is not None and task_name == "s2c":
|
807 |
+
if target_list is not None and target_list.size(1) == 1 and not getattr(self.args, "sid_t5_postnet", False):
|
808 |
+
sid_target = F.one_hot(target_list.squeeze(1), num_classes=self.speaker_decoder_postnet.class_num)
|
809 |
+
else:
|
810 |
+
sid_target = None
|
811 |
+
target_list = None
|
812 |
+
|
813 |
+
# Encoder Prenet
|
814 |
+
if input_type == 'text':
|
815 |
+
encoder_input, encoder_padding_mask = self.text_encoder_prenet(src_tokens)
|
816 |
+
else:
|
817 |
+
if target_list is not None:
|
818 |
+
encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, require_feat_pen=True, target_list=target_list, padding_mask=padding_mask, mask=mask)
|
819 |
+
encoder_input, features_pen, mask_indices, target_list = encoder_input
|
820 |
+
else:
|
821 |
+
encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, padding_mask=padding_mask, mask=self.training)
|
822 |
+
# optionally shuffle the encoder input frames along the time axis (sid_shuffle_encoder_input)
|
823 |
+
if self.training and hasattr(self.args, "sid_shuffle_encoder_input") and getattr(self.args, "sid_shuffle_encoder_input", False):
|
824 |
+
shuffle_index = torch.randperm(encoder_padding_mask.size(1), device=encoder_padding_mask.device)
|
825 |
+
encoder_input = torch.index_select(encoder_input, 1, shuffle_index)
|
826 |
+
encoder_padding_mask = torch.index_select(encoder_padding_mask, 1, shuffle_index)
|
827 |
+
if getattr(self.args, "sid_encoder_cls", None) == "encoder":
|
828 |
+
prev_output_tokens = torch.zeros_like(prev_output_tokens)
|
829 |
+
encoder_input, encoder_padding_mask = self._integrate_with_speaker_cls(prev_output_tokens, encoder_input, encoder_padding_mask)
|
830 |
+
|
831 |
+
# Encoder: T x B x C
|
832 |
+
encoder_output = self.encoder(encoder_input, encoder_padding_mask, tgt_layer=tgt_enc_layer)
|
833 |
+
|
834 |
+
if task_name is not None and task_name == 'speech_pretrain' and feature_only:
|
835 |
+
return encoder_output["encoder_out"][0].transpose(0, 1)
|
836 |
+
|
837 |
+
if task_name is not None and task_name == 's2c':
|
838 |
+
if self.args.sid_pooling_layer == "encoder":
|
839 |
+
return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1).mean(1), sid_target), None
|
840 |
+
elif self.args.sid_pooling_layer == "encoder-cls":
|
841 |
+
return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1)[:,0], sid_target), None
|
842 |
+
elif self.args.sid_pooling_layer == "encoder-speaker" or getattr(self.args, "sid_decoder_speaker", False):
|
843 |
+
return self.speaker_decoder_postnet(encoder_output["encoder_out"][0].transpose(0, 1), sid_target), None
|
844 |
+
|
845 |
+
if target_list is not None:
|
846 |
+
hubert_results = self.hubert_layer(
|
847 |
+
encoder_output["encoder_out"][0].transpose(0, 1),
|
848 |
+
encoder_padding_mask,
|
849 |
+
mask_indices,
|
850 |
+
target_list
|
851 |
+
)
|
852 |
+
|
853 |
+
hubert_results['features_pen'] = features_pen
|
854 |
+
|
855 |
+
if "decoder_input" in encoder_output and encoder_output["decoder_input"][0] is not None:
|
856 |
+
# Replace the encoder output with the intermediate decoder_input once unb-enc-layer is set
|
857 |
+
encoder_output["encoder_out"] = encoder_output["decoder_input"]
|
858 |
+
|
859 |
+
if self.use_codebook:
|
860 |
+
q = self.quantizer(encoder_output["encoder_out"][0].transpose(0, 1))
|
861 |
+
|
862 |
+
# q["x"]: B x T x C
|
863 |
+
# Randomly sample a codebook_prob fraction of the time indices
|
864 |
+
random_idx = torch.randperm(q["x"].size(1))[:int(q["x"].size(1) * self.codebook_prob)]
|
865 |
+
# Make weight for q
|
866 |
+
q_w = q["x"].new_zeros(q["x"].size(1))
|
867 |
+
q_w[random_idx] = 1.0
|
868 |
+
# Combine quantized codes and encoder output
|
869 |
+
encoder_output["encoder_out"][0] = (
|
870 |
+
q_w.view(-1, 1) * q["x"] + (- q_w + 1).view(-1, 1) * encoder_output["encoder_out"][0].transpose(0, 1)
|
871 |
+
).transpose(0, 1)
|
872 |
+
|
873 |
+
# encoder_output["encoder_out"][0] = q["x"].transpose(0, 1)
|
874 |
+
if output_type == 'speech':
|
875 |
+
hubert_results["prob_perplexity"] = q["prob_perplexity"]
|
876 |
+
hubert_results["code_perplexity"] = q["code_perplexity"]
|
877 |
+
hubert_results["num_vars"] = q["num_vars"]
|
878 |
+
hubert_results["temp"] = q["temp"]
|
879 |
+
elif output_type == 'text':
|
880 |
+
codebook_out["prob_perplexity"] = q["prob_perplexity"]
|
881 |
+
codebook_out["code_perplexity"] = q["code_perplexity"]
|
882 |
+
codebook_out["num_vars"] = q["num_vars"]
|
883 |
+
codebook_out["temp"] = q["temp"]
|
884 |
+
|
885 |
+
if only_hubert and target_list is not None:
|
886 |
+
return hubert_results, None
|
887 |
+
|
888 |
+
if only_ctc and task_name is not None and task_name == "s2t":
|
889 |
+
return None, encoder_output
|
890 |
+
elif not self.training and prev_output_tokens is None and task_name == "s2t" and task_name is not None:
|
891 |
+
return encoder_output
|
892 |
+
|
893 |
+
# Decoder Prenet
|
894 |
+
if output_type == 'text':
|
895 |
+
# _ is the incremental state
|
896 |
+
prev_output_tokens, tgt_mask, _ = self.text_decoder_prenet(prev_output_tokens)
|
897 |
+
if task_name is not None and task_name == 's2c':
|
898 |
+
prev_output_tokens = torch.zeros_like(prev_output_tokens)
|
899 |
+
else:
|
900 |
+
# integrate speaker embedding
|
901 |
+
if self.spk_embed_integration_type == "pre" and self.spk_embed_dim is not None:
|
902 |
+
# Decoder Prenet
|
903 |
+
prev_output_tokens, tgt_mask = self.speech_decoder_prenet(prev_output_tokens, tgt_lengths, spkembs)
|
904 |
+
else:
|
905 |
+
if self.spk_embed_dim is not None:
|
906 |
+
encoder_output["encoder_out"] = [self._integrate_with_spk_embed(
|
907 |
+
encoder_output["encoder_out"][0].transpose(0, 1), spkembs
|
908 |
+
).transpose(0, 1)]
|
909 |
+
|
910 |
+
prev_output_tokens, tgt_mask = self.speech_decoder_prenet(prev_output_tokens, tgt_lengths)
|
911 |
+
|
912 |
+
# BART-style sequence classification: concatenate <pad> and the features before the decoder
|
913 |
+
if task_name is not None and task_name == 's2c' and self.args.sid_pooling_layer == "decoder-las":
|
914 |
+
decoder_feat_input, decoder_feat_mask = self.speech_decoder_prenet(src_tokens, src_lengths)
|
915 |
+
prev_output_tokens, tgt_mask = self._integrate_with_speaker_cls((prev_output_tokens, tgt_mask), decoder_feat_input, decoder_feat_mask, cls_first=False)
|
916 |
+
|
917 |
+
# SE prediction: masking is applied to the corresponding inputs, and the source speech replaces prev_output_tokens as the decoder input
|
918 |
+
if task_name is not None and task_name == "s2s" and getattr(self.args, "se_decoder_input", "previous_target") == "source":
|
919 |
+
prev_output_tokens, tgt_mask = self.speech_decoder_prenet(src_tokens, src_lengths)
|
920 |
+
|
921 |
+
# Decoder
|
922 |
+
decoder_output, extra = self.decoder(prev_output_tokens, tgt_mask, encoder_output,
|
923 |
+
full_context_alignment=getattr(self.args, "decoder_full_context_alignment", False),
|
924 |
+
alignment_layer=(-1 if target_list is None and output_type == 'speech' else None))
|
925 |
+
# Decoder Postnet
|
926 |
+
if task_name is not None and task_name == 's2c':
|
927 |
+
if not getattr(self.args, "sid_t5_postnet", False):
|
928 |
+
if self.args.sid_pooling_layer == "decoder":
|
929 |
+
return self.speaker_decoder_postnet(decoder_output.mean(1), sid_target), None
|
930 |
+
elif self.args.sid_pooling_layer == "decoder-las":
|
931 |
+
indices = (tgt_mask.eq(False).float().sum(1) - 1.0).type(torch.int64)
|
932 |
+
indices = indices.unsqueeze(1).unsqueeze(2).expand(-1, -1, decoder_output.size(2))
|
933 |
+
return self.speaker_decoder_postnet(decoder_output.gather(1, indices), sid_target), None
|
934 |
+
else:
|
935 |
+
return (self.text_decoder_postnet(decoder_output), None), encoder_output
|
936 |
+
|
937 |
+
# SE prediction modes: masking, target, or delta; requires reduction_factor == 1
|
938 |
+
if task_name is not None and task_name == 's2s' and getattr(self.args, "se_predict", None) is not None:
|
939 |
+
assert self.reduction_factor == 1, f"{self.reduction_factor} != 1"
|
940 |
+
before_outs, after_outs, logits = self.speech_decoder_postnet(decoder_output)
|
941 |
+
se_predict = getattr(self.args, "se_predict")
|
942 |
+
if se_predict == "masking":
|
943 |
+
before_outs = torch.sigmoid(before_outs) * src_tokens
|
944 |
+
after_outs = torch.sigmoid(after_outs) * src_tokens
|
945 |
+
return before_outs, after_outs, logits, extra['attn'][0]
|
946 |
+
elif se_predict == "target":
|
947 |
+
return before_outs, after_outs, logits, extra['attn'][0]
|
948 |
+
elif se_predict == "delta":
|
949 |
+
before_outs = before_outs - src_tokens
|
950 |
+
after_outs = after_outs - src_tokens
|
951 |
+
return before_outs, after_outs, logits, extra['attn'][0]
|
952 |
+
else:
|
953 |
+
raise ValueError(f"{se_predict} not in [masking, target, delta]")
|
954 |
+
|
955 |
+
if task_name is not None and task_name == 's2t':
|
956 |
+
#return self.text_decoder_postnet(decoder_output), None
|
957 |
+
return (self.text_decoder_postnet(decoder_output), None), encoder_output
|
958 |
+
if output_type == 'text':
|
959 |
+
return (self.text_decoder_postnet(decoder_output), None), codebook_out, encoder_output
|
960 |
+
else:
|
961 |
+
if target_list is not None:
|
962 |
+
return hubert_results, (self.speech_decoder_postnet(decoder_output) + (extra['attn'][0],))
|
963 |
+
else:
|
964 |
+
return self.speech_decoder_postnet(decoder_output) + (extra['attn'][0],)
|
965 |
+
|
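The forward method above multiplexes several tasks through one code path, and the return structure differs per branch. As a reading aid only (not part of the model, and omitting the feature_only, only_ctc, and se_predict early returns), here is a small runnable summary of the main cases:

# Simplified map of what forward() returns for the main task_name branches above.
RETURN_STRUCTURES = {
    "s2t": "((text_logits, None), encoder_output)",
    "s2c": "(speaker_postnet_output, None)",
    "speech_pretrain (with target_list)": "(hubert_results, (before_outs, after_outs, stop_logits, attn))",
    "t2s / s2s (speech output)": "(before_outs, after_outs, stop_logits, attn)",
    "text output with codebook": "((text_logits, None), codebook_out, encoder_output)",
}

for branch, structure in RETURN_STRUCTURES.items():
    print(f"{branch:38s} -> {structure}")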
966 |
+
def _integrate_with_speaker_cls(self, pad_input, encoder_input, encoder_padding_mask=None, cls_first=True):
|
967 |
+
"""
|
968 |
+
encoder_input: [B, T, C]
|
969 |
+
encoder_padding_mask: [B, T]
|
970 |
+
"""
|
971 |
+
if hasattr(self, "text_decoder_prenet"):
|
972 |
+
if isinstance(pad_input, tuple):
|
973 |
+
repeat_cls_vector, repeat_cls_mask = pad_input
|
974 |
+
else:
|
975 |
+
repeat_cls_vector, repeat_cls_mask, _ = self.text_decoder_prenet(pad_input)
|
976 |
+
|
977 |
+
if encoder_padding_mask is not None:
|
978 |
+
bsz = encoder_input.size(0)
|
979 |
+
tsz = encoder_input.size(1)
|
980 |
+
encoder_padding_mask = encoder_input.new_zeros((bsz, tsz)) == 1.0
|
981 |
+
if repeat_cls_mask is None:
|
982 |
+
mask_size = (encoder_padding_mask.size(0), 1)
|
983 |
+
mask_type = encoder_padding_mask.dtype
|
984 |
+
repeat_cls_mask = encoder_padding_mask.new_zeros(mask_size) == 1.0
|
985 |
+
ret_encoder_padding_mask = torch.cat([repeat_cls_mask, encoder_padding_mask], dim=1)
|
986 |
+
|
987 |
+
if cls_first:
|
988 |
+
ret_encoder_input = torch.cat([repeat_cls_vector, encoder_input], dim=1)
|
989 |
+
else:
|
990 |
+
ret_encoder_input = torch.cat([encoder_input, encoder_input[:,-1:,:]], dim=1)
|
991 |
+
mask_size = (encoder_padding_mask.size(0), 1)
|
992 |
+
mask_type = encoder_padding_mask.dtype
|
993 |
+
repeat_cls_mask_ = encoder_padding_mask.new_ones(mask_size) == 1.0
|
994 |
+
encoder_padding_mask_ = torch.cat([encoder_padding_mask, repeat_cls_mask_], dim=1)
|
995 |
+
indices = encoder_padding_mask.eq(False).float().sum(1).type(torch.int64).unsqueeze(1)
|
996 |
+
indices_mask = torch.zeros_like(ret_encoder_padding_mask).scatter(1, indices, 1.0)
|
997 |
+
ret_encoder_input = ret_encoder_input * (1.0 - encoder_padding_mask_.type(ret_encoder_input.dtype).unsqueeze(2)) \
|
998 |
+
+ repeat_cls_vector * indices_mask.type(repeat_cls_vector.dtype).unsqueeze(2)
|
999 |
+
|
1000 |
+
return ret_encoder_input, ret_encoder_padding_mask
|
1001 |
+
|
1002 |
+
def _integrate_with_spk_embed(self, hs, spembs):
|
1003 |
+
"""Integrate speaker embedding with hidden states.
|
1004 |
+
Args:
|
1005 |
+
hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
|
1006 |
+
spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).
|
1007 |
+
Returns:
|
1008 |
+
Tensor: Batch of integrated hidden state sequences (B, Tmax, adim)
|
1009 |
+
"""
|
1010 |
+
if self.spk_embed_integration_type == "add":
|
1011 |
+
# apply projection and then add to hidden states
|
1012 |
+
spembs = self.projection(F.normalize(spembs))
|
1013 |
+
hs = hs + spembs.unsqueeze(1)
|
1014 |
+
elif self.spk_embed_integration_type == "concat":
|
1015 |
+
# concat hidden states with spk embeds and then apply projection
|
1016 |
+
spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
|
1017 |
+
hs = self.projection(torch.cat([hs, spembs], dim=-1))
|
1018 |
+
else:
|
1019 |
+
raise NotImplementedError("only 'add' and 'concat' speaker embedding integration are supported.")
|
1020 |
+
|
1021 |
+
return hs
|
1022 |
+
|
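A minimal runnable sketch of the two speaker-embedding integration modes handled above: "add" projects the normalized x-vector to the model dimension and broadcast-adds it to every frame, while "concat" tiles it along time and projects the concatenation back down. All dimensions below are made up for illustration, and the Linear layers stand in for self.projection:

import torch
import torch.nn.functional as F
from torch import nn

B, T, adim, spk_dim = 2, 100, 768, 512
hs = torch.randn(B, T, adim)          # hidden state sequence
spembs = torch.randn(B, spk_dim)      # speaker x-vectors

# "add": project to adim, then broadcast-add over the time axis
proj_add = nn.Linear(spk_dim, adim)
hs_add = hs + proj_add(F.normalize(spembs)).unsqueeze(1)

# "concat": tile along time, concatenate, project back to adim
proj_cat = nn.Linear(adim + spk_dim, adim)
tiled = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
hs_cat = proj_cat(torch.cat([hs, tiled], dim=-1))

print(hs_add.shape, hs_cat.shape)     # both torch.Size([2, 100, 768])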
1023 |
+
def load_state_dict(
|
1024 |
+
self,
|
1025 |
+
state_dict,
|
1026 |
+
strict=True,
|
1027 |
+
model_cfg=None,
|
1028 |
+
args=None,
|
1029 |
+
):
|
1030 |
+
"""NOT STRICT Copies parameters and buffers from *state_dict* into this module and
|
1031 |
+
its descendants.
|
1032 |
+
|
1033 |
+
Overrides the method in :class:`nn.Module`. Compared with that method
|
1034 |
+
this additionally "upgrades" *state_dicts* from old checkpoints.
|
1035 |
+
"""
|
1036 |
+
# self.prune_modules(model_cfg.modules_filter)
|
1037 |
+
model_dict_size = self.text_decoder_postnet.output_projection.out_features
|
1038 |
+
ckpt_dict_size = state_dict["text_decoder_postnet.output_projection.weight"].size(0)
|
1039 |
+
if model_dict_size != ckpt_dict_size:
|
1040 |
+
# reset dictionary-related modules, such as embedding table and encoder ctc embed
|
1041 |
+
logger.warn(f"not equal dictionary between model and checkpoint: {model_dict_size} vs {ckpt_dict_size}")
|
1042 |
+
logger.info(f"reset model dictionary with size of {model_dict_size}")
|
1043 |
+
removed_keys = [
|
1044 |
+
key for key in state_dict.keys() if any(
|
1045 |
+
key.startswith(previ) for previ in [
|
1046 |
+
"encoder.proj", "text_encoder_prenet", "text_decoder_prenet", "text_decoder_postnet"
|
1047 |
+
]
|
1048 |
+
)
|
1049 |
+
]
|
1050 |
+
for key in removed_keys:
|
1051 |
+
state_dict.pop(key, None)
|
1052 |
+
logger.info(f"removed loaded checkpoint: {key}")
|
1053 |
+
for m in self._modules.keys():
|
1054 |
+
m_state_dict = {
|
1055 |
+
key.replace(f"{m}.", ""): value for key, value in state_dict.items() if key.startswith(f"{m}.")
|
1056 |
+
}
|
1057 |
+
if hasattr(self, m):
|
1058 |
+
self._modules[m].load_state_dict(m_state_dict, False)
|
1059 |
+
return self
|
1060 |
+
|
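load_state_dict above deliberately loads non-strictly: after dropping text/dictionary-dependent weights on a size mismatch, it filters checkpoint keys by each submodule's name prefix and loads that submodule with strict=False. A self-contained sketch of the prefix-filtering pattern with generic modules (not the actual ArTST submodules):

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.decoder = nn.Linear(4, 4)

ckpt = Toy().state_dict()          # pretend this came from a checkpoint
ckpt.pop("decoder.bias")           # simulate a pruned / incompatible key

model = Toy()
for name, module in model._modules.items():
    sub = {k[len(name) + 1:]: v for k, v in ckpt.items() if k.startswith(f"{name}.")}
    missing, unexpected = module.load_state_dict(sub, strict=False)
    print(name, "missing:", missing, "unexpected:", unexpected)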
1061 |
+
def prune_modules(self, modules_filter=None):
|
1062 |
+
"""Prune unused modules for specific tasks."""
|
1063 |
+
if modules_filter is None:
|
1064 |
+
return
|
1065 |
+
elif modules_filter == "s2c":
|
1066 |
+
if hasattr(self, "text_encoder_prenet"): del self.text_encoder_prenet
|
1067 |
+
if hasattr(self, "speech_decoder_prenet") and getattr(self.args, "sid_pooling_layer", None) != "decoder-las":
|
1068 |
+
del self.speech_decoder_prenet
|
1069 |
+
if hasattr(self, "speech_decoder_postnet"): del self.speech_decoder_postnet
|
1070 |
+
if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
|
1071 |
+
if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
|
1072 |
+
if hasattr(self.encoder, "proj"): self.encoder.proj = None
|
1073 |
+
if hasattr(self, "projection"): del self.projection
|
1074 |
+
if hasattr(self, "quantizer"): del self.quantizer
|
1075 |
+
if getattr(self.args, "sid_pooling_layer", "decoder").startswith("encoder") or getattr(self.args, "sid_decoder_speaker", False):
|
1076 |
+
if hasattr(self.decoder, "dropout_module"): del self.decoder.dropout_module
|
1077 |
+
if hasattr(self.decoder, "layers"): del self.decoder.layers
|
1078 |
+
if hasattr(self.decoder, "layer_norm"): del self.decoder.layer_norm
|
1079 |
+
if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
|
1080 |
+
elif modules_filter == "s2s":
|
1081 |
+
if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet
|
1082 |
+
if hasattr(self, "text_encoder_prenet"): del self.text_encoder_prenet
|
1083 |
+
if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
|
1084 |
+
if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
|
1085 |
+
if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
|
1086 |
+
if hasattr(self.encoder, "proj"): self.encoder.proj = None
|
1087 |
+
if hasattr(self, "projection"): del self.projection
|
1088 |
+
if hasattr(self, "quantizer"): del self.quantizer
|
1089 |
+
elif modules_filter == "t2s":
|
1090 |
+
if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet
|
1091 |
+
if hasattr(self, "speech_encoder_prenet"): del self.speech_encoder_prenet
|
1092 |
+
if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
|
1093 |
+
if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
|
1094 |
+
if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
|
1095 |
+
if hasattr(self.encoder, "proj"): self.encoder.proj = None
|
1096 |
+
if hasattr(self, "projection"): del self.projection
|
1097 |
+
if hasattr(self, "quantizer"): del self.quantizer
|
1098 |
+
elif modules_filter == "s3prl":
|
1099 |
+
# keep the encoder path (and its speech prenet); drop the decoder internals and the task-specific pre/post nets
|
1100 |
+
if hasattr(self.decoder, "dropout_module"): del self.decoder.dropout_module
|
1101 |
+
if hasattr(self.decoder, "layers"): del self.decoder.layers
|
1102 |
+
if hasattr(self.decoder, "layer_norm"): del self.decoder.layer_norm
|
1103 |
+
if hasattr(self, "speaker_decoder_postnet"): del self.speaker_decoder_postnet
|
1104 |
+
if hasattr(self, "text_decoder_prenet"): del self.text_decoder_prenet
|
1105 |
+
if hasattr(self, "text_decoder_postnet"): del self.text_decoder_postnet
|
1106 |
+
if hasattr(self, "speech_decoder_prenet"): del self.speech_decoder_prenet
|
1107 |
+
if hasattr(self, "speech_decoder_postnet"): del self.speech_decoder_postnet
|
1108 |
+
if hasattr(self, "speech_encoder_postnet"): del self.speech_encoder_postnet
|
1109 |
+
if hasattr(self.encoder, "proj"): self.encoder.proj = None
|
1110 |
+
if hasattr(self, "projection"): del self.projection
|
1111 |
+
if hasattr(self, "quantizer"): del self.quantizer
|
1112 |
+
|
1113 |
+
def forward_encoder_torchscript(self, net_input: Dict[str, Tensor]):
|
1114 |
+
"""A TorchScript-compatible version of forward.
|
1115 |
+
|
1116 |
+
Encoders which use additional arguments may want to override
|
1117 |
+
this method for TorchScript compatibility.
|
1118 |
+
"""
|
1119 |
+
if torch.jit.is_scripting():
|
1120 |
+
return self.forward_encoder(
|
1121 |
+
source=net_input["source"],
|
1122 |
+
padding_mask=net_input["padding_mask"]
|
1123 |
+
)
|
1124 |
+
else:
|
1125 |
+
return self.forward_encoder_non_torchscript(net_input)
|
1126 |
+
|
1127 |
+
@torch.jit.unused
|
1128 |
+
def forward_encoder_non_torchscript(self, net_input: Dict[str, Tensor]):
|
1129 |
+
encoder_input = {
|
1130 |
+
k: v for k, v in net_input.items() if k != "prev_output_tokens" and k != "task_name"
|
1131 |
+
}
|
1132 |
+
return self.forward_encoder(**encoder_input)
|
1133 |
+
|
1134 |
+
def forward_encoder(self, source, padding_mask=None):
|
1135 |
+
# Encoder Prenet
|
1136 |
+
encoder_input, encoder_padding_mask = self.speech_encoder_prenet(source, padding_mask=padding_mask, mask=False)
|
1137 |
+
|
1138 |
+
# Encoder
|
1139 |
+
encoder_output = self.encoder(encoder_input, encoder_padding_mask)
|
1140 |
+
|
1141 |
+
return encoder_output
|
1142 |
+
|
1143 |
+
def forward_text_encoder(self, src_tokens):
|
1144 |
+
# Text Encoder Prenet
|
1145 |
+
encoder_input, encoder_padding_mask = self.text_encoder_prenet(src_tokens)
|
1146 |
+
|
1147 |
+
# Encoder
|
1148 |
+
encoder_output = self.encoder(encoder_input, encoder_padding_mask)
|
1149 |
+
|
1150 |
+
return encoder_output
|
1151 |
+
|
1152 |
+
def forward_decoder(self, tokens, encoder_out, incremental_state):
|
1153 |
+
# Decoder Prenet
|
1154 |
+
prev_output_tokens, tgt_mask, incremental_state = self.text_decoder_prenet(tokens, incremental_state)
|
1155 |
+
|
1156 |
+
# Decoder
|
1157 |
+
decoder_output, extra = self.decoder(
|
1158 |
+
prev_output_tokens,
|
1159 |
+
tgt_mask,
|
1160 |
+
encoder_out=encoder_out,
|
1161 |
+
incremental_state=incremental_state,
|
1162 |
+
)
|
1163 |
+
|
1164 |
+
# Decoder Postnet
|
1165 |
+
return self.text_decoder_postnet(decoder_output), extra
|
1166 |
+
|
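forward_encoder / forward_decoder split the model so that a search loop can run the encoder once and then step the decoder with a cached incremental_state. The real loop lives in the sequence generator; the toy below only shows the generic shape of such a greedy loop, with dummy_step standing in for forward_decoder (which additionally threads encoder_out and the incremental state, as above):

import torch

def dummy_step(tokens, incremental_state):
    """Stand-in for forward_decoder: return next-token logits and the updated cache."""
    torch.manual_seed(tokens.size(1))          # deterministic toy logits
    return torch.randn(tokens.size(0), 10), incremental_state

bos, eos, max_len = 1, 2, 20
tokens = torch.full((1, 1), bos, dtype=torch.long)
state = {}                                     # cached decoder state between steps

for _ in range(max_len):
    logits, state = dummy_step(tokens, state)
    next_tok = logits.argmax(dim=-1, keepdim=True)
    tokens = torch.cat([tokens, next_tok], dim=1)
    if int(next_tok) == eos:
        break

print(tokens.tolist())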
1167 |
+
def set_num_updates(self, num_updates):
|
1168 |
+
"""Set the number of parameters updates."""
|
1169 |
+
super().set_num_updates(num_updates)
|
1170 |
+
self.num_updates = num_updates
|
1171 |
+
|
1172 |
+
def generate_class(self, source, prev_output_tokens, **kwargs):
|
1173 |
+
encoder_out = self.forward_encoder(source, padding_mask=kwargs["padding_mask"])
|
1174 |
+
|
1175 |
+
prev_output_tokens, tgt_mask, _ = self.text_decoder_prenet(prev_output_tokens, {})
|
1176 |
+
prev_output_tokens = torch.zeros_like(prev_output_tokens) # s2c use zero vector as [CLS]
|
1177 |
+
|
1178 |
+
decoder_output, extra = self.decoder(
|
1179 |
+
prev_output_tokens,
|
1180 |
+
tgt_mask,
|
1181 |
+
encoder_out=encoder_out,
|
1182 |
+
)
|
1183 |
+
|
1184 |
+
decoder_out, embed = self.speaker_decoder_postnet(decoder_output.mean(1))
|
1185 |
+
|
1186 |
+
pred_class = decoder_out.argmax(1)
|
1187 |
+
return pred_class
|
1188 |
+
|
1189 |
+
def generate_speech(self, source=None, src_tokens=None, spkembs=None, **kwargs):
|
1190 |
+
assert source is not None or src_tokens is not None
|
1191 |
+
|
1192 |
+
threshold = kwargs.get("threshold", 0.5)
|
1193 |
+
minlenratio = kwargs.get("minlenratio", 0.0)
|
1194 |
+
|
1195 |
+
if source is None:
|
1196 |
+
assert src_tokens.size(0) == 1
|
1197 |
+
encoder_out = self.forward_text_encoder(src_tokens)
|
1198 |
+
maxlenratio = kwargs.get("maxlenratio", 20.0)
|
1199 |
+
else:
|
1200 |
+
assert source.size(0) == 1
|
1201 |
+
encoder_out = self.forward_encoder(source, padding_mask=kwargs["padding_mask"])
|
1202 |
+
maxlenratio = kwargs.get("maxlenratio", 10.0)
|
1203 |
+
|
1204 |
+
if spkembs is not None and self.spk_embed_integration_type != "pre":
|
1205 |
+
encoder_out["encoder_out"] = [self._integrate_with_spk_embed(
|
1206 |
+
encoder_out["encoder_out"][0].transpose(0, 1), spkembs
|
1207 |
+
).transpose(0, 1)]
|
1208 |
+
spkembs = None
|
1209 |
+
|
1210 |
+
maxlen = int(encoder_out["encoder_out"][0].size(0) * maxlenratio / self.reduction_factor)
|
1211 |
+
minlen = int(encoder_out["encoder_out"][0].size(0) * minlenratio / self.reduction_factor)
|
1212 |
+
|
1213 |
+
idx = 0
|
1214 |
+
ys = encoder_out["encoder_out"][0].new_zeros(1, 1, self.speech_decoder_postnet.odim)
|
1215 |
+
outs, probs = [], []
|
1216 |
+
|
1217 |
+
# forward decoder step-by-step
|
1218 |
+
if isinstance(self.decoder, FairseqIncrementalDecoder):
|
1219 |
+
incremental_states = {}
|
1220 |
+
else:
|
1221 |
+
incremental_states = None
|
1222 |
+
attns = []
|
1223 |
+
while True:
|
1224 |
+
# update index
|
1225 |
+
idx += 1
|
1226 |
+
# calculate output and stop prob at idx-th step
|
1227 |
+
decoder_in, _ = self.speech_decoder_prenet(ys, spkembs=spkembs)
|
1228 |
+
z, extra = self.decoder(decoder_in[:,-1:], None, encoder_out, incremental_states, alignment_layer=-1)
|
1229 |
+
outs += [self.speech_decoder_postnet.feat_out(z[0, -1]).view(self.reduction_factor, self.speech_decoder_postnet.odim)] # [(r, odim), ...]
|
1230 |
+
probs += [torch.sigmoid(self.speech_decoder_postnet.prob_out(z[0, -1]))] # [(r), ...]
|
1231 |
+
|
1232 |
+
# update next inputs
|
1233 |
+
ys = torch.cat((ys, outs[-1][-1].view(1, 1, self.speech_decoder_postnet.odim)), dim=1) # (1, idx + 1, odim)
|
1234 |
+
attns.append(torch.stack([att_l[0] for att_l in extra['attn'][0]], dim=0))
|
1235 |
+
# check whether to finish generation
|
1236 |
+
if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen:
|
1237 |
+
# check minimum length
|
1238 |
+
if idx < minlen:
|
1239 |
+
continue
|
1240 |
+
outs = (torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2)) # (L, odim) -> (1, L, odim) -> (1, odim, L)
|
1241 |
+
if self.speech_decoder_postnet.postnet is not None:
|
1242 |
+
outs = outs + self.speech_decoder_postnet.postnet(outs) # (1, odim, L)
|
1243 |
+
outs = outs.transpose(2, 1).squeeze(0) # (L, odim)
|
1244 |
+
probs = torch.cat(probs, dim=0)
|
1245 |
+
attn = torch.cat(attns, dim=2)
|
1246 |
+
break
|
1247 |
+
|
1248 |
+
if outs.size(0) == maxlen:
|
1249 |
+
logging.warning("output length reaches maximum length")
|
1250 |
+
return outs, probs, attn
|
1251 |
+
|
1252 |
+
|
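generate_speech above decodes mel frames autoregressively and stops once any stop probability crosses threshold, with minimum and maximum lengths derived from the encoder length, the length ratios, and the reduction factor. A toy version of just that stopping logic, with random numbers in place of the real prob_out head:

import torch

enc_len, reduction_factor = 120, 2
minlenratio, maxlenratio, threshold = 0.0, 10.0, 0.5
minlen = int(enc_len * minlenratio / reduction_factor)
maxlen = int(enc_len * maxlenratio / reduction_factor)

torch.manual_seed(0)
idx = 0
while True:
    idx += 1
    stop_prob = torch.sigmoid(torch.randn(reduction_factor))   # fake stop probabilities
    if (int((stop_prob >= threshold).sum()) > 0 or idx >= maxlen) and idx >= minlen:
        break

print(f"stopped after {idx} decoder steps (minlen={minlen}, maxlen={maxlen})")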
1253 |
+
@register_model_architecture(model_name="artst_transformer", arch_name="artst_transformer")
|
1254 |
+
def base_architecture(args):
|
1255 |
+
# Transformer
|
1256 |
+
args.bert_init = getattr(args, "bert_init", False)
|
1257 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
|
1258 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 768 * 4)
|
1259 |
+
args.encoder_layers = getattr(args, "encoder_layers", 12)
|
1260 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
|
1261 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
1262 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
|
1263 |
+
args.decoder_ffn_embed_dim = getattr(
|
1264 |
+
args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
|
1265 |
+
)
|
1266 |
+
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
1267 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
|
1268 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
|
1269 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
1270 |
+
args.attention_dropout = getattr(args, "attention_dropout", args.dropout)
|
1271 |
+
args.activation_dropout = getattr(args, "activation_dropout", args.dropout)
|
1272 |
+
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
1273 |
+
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
|
1274 |
+
args.decoder_output_dim = getattr(
|
1275 |
+
args, "decoder_output_dim", args.decoder_embed_dim
|
1276 |
+
)
|
1277 |
+
args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
|
1278 |
+
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
|
1279 |
+
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
|
1280 |
+
args.max_text_positions = getattr(args, "max_text_positions", DEFAULT_MAX_TEXT_POSITIONS)
|
1281 |
+
args.max_speech_positions = getattr(args, "max_speech_positions", DEFAULT_MAX_SPEECH_POSITIONS)
|
1282 |
+
|
1283 |
+
# ESPnet-related settings, including the prenet and postnet
|
1284 |
+
args.eprenet_conv_layers = getattr(args, "eprenet_conv_layers", 0)
|
1285 |
+
args.eprenet_conv_filts = getattr(args, "eprenet_conv_filts", 0)
|
1286 |
+
args.eprenet_conv_chans = getattr(args, "eprenet_conv_chans", 0)
|
1287 |
+
args.use_batch_norm = getattr(args, "use_batch_norm", True)
|
1288 |
+
args.eprenet_dropout_rate = getattr(args, "eprenet_dropout_rate", 0.0)
|
1289 |
+
args.enc_use_scaled_pos_enc = getattr(args, "enc_use_scaled_pos_enc", True)
|
1290 |
+
args.dec_use_scaled_pos_enc = getattr(args, "dec_use_scaled_pos_enc", True)
|
1291 |
+
args.postnet_layers = getattr(args, "postnet_layers", 5)
|
1292 |
+
args.postnet_chans = getattr(args, "postnet_chans", 256)
|
1293 |
+
args.postnet_filts = getattr(args, "postnet_filts", 5)
|
1294 |
+
args.postnet_dropout_rate = getattr(args, "postnet_dropout_rate", 0.5)
|
1295 |
+
args.dprenet_dropout_rate = getattr(args, "dprenet_dropout_rate", 0.5)
|
1296 |
+
args.dprenet_layers = getattr(args, "dprenet_layers", 2)
|
1297 |
+
args.dprenet_units = getattr(args, "dprenet_units", 256)
|
1298 |
+
args.initial_encoder_alpha = getattr(args, "initial_encoder_alpha", 1.0)
|
1299 |
+
args.initial_decoder_alpha = getattr(args, "initial_decoder_alpha", 1.0)
|
1300 |
+
args.spk_embed_integration_type = getattr(args, "spk_embed_integration_type", "pre")
|
1301 |
+
args.spk_embed_dim = getattr(args, "spk_embed_dim", 512)
|
1302 |
+
args.encoder_reduction_factor = getattr(args, "encoder_reduction_factor", 1)
|
1303 |
+
args.reduction_factor = getattr(args, "reduction_factor", 2)
|
1304 |
+
args.transformer_enc_positional_dropout_rate = getattr(args, "transformer_enc_positional_dropout_rate", 0.1)
|
1305 |
+
args.transformer_dec_positional_dropout_rate = getattr(args, "transformer_dec_positional_dropout_rate", 0.1)
|
1306 |
+
args.layer_norm_eps = getattr(args, "layer_norm_eps", 1e-5)
|
1307 |
+
args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
|
1308 |
+
# Convolutional subsampler
|
1309 |
+
args.encoder_speech_prenet = getattr(args, "encoder_speech_prenet", "conv")
|
1310 |
+
args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5")
|
1311 |
+
args.conv_channels = getattr(args, "conv_channels", 1024)
|
1312 |
+
args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
|
1313 |
+
|
1314 |
+
args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
|
1315 |
+
args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
|
1316 |
+
args.no_token_positional_embeddings = getattr(
|
1317 |
+
args, "no_token_positional_embeddings", False
|
1318 |
+
)
|
1319 |
+
args.adaptive_input = getattr(args, "adaptive_input", False)
|
1320 |
+
args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
|
1321 |
+
args.share_input_output_embed = getattr(args, "share_input_output_embed", False)
|
1322 |
+
args.share_ctc_embed = getattr(args, "share_ctc_embed", False)
|
1323 |
+
args.freeze_encoder_updates = getattr(args, "freeze_encoder_updates", 0)
|
1324 |
+
args.freeze_decoder_updates = getattr(args, "freeze_decoder_updates", 0)
|
1325 |
+
args.no_freeze_encoder_layer = getattr(args, "no_freeze_encoder_layer", None)
|
1326 |
+
|
1327 |
+
## sid
|
1328 |
+
args.sid_embed_dim = getattr(args, "sid_embed_dim", 128)
|
1329 |
+
args.sid_pooling_layer = getattr(args, "sid_pooling_layer", "decoder")
|
1330 |
+
args.softmax_scale = getattr(args, "softmax_scale", 1)
|
1331 |
+
args.softmax_margin = getattr(args, "softmax_margin", 0)
|
1332 |
+
args.softmax_easy_margin = getattr(args, "softmax_easy_margin", False)
|
1333 |
+
args.modules_filter = getattr(args, "modules_filter", None)
|
1334 |
+
|
1335 |
+
## Hubert
|
1336 |
+
args.conv_pos = getattr(args, "conv_pos", 128)
|
1337 |
+
args.conv_pos_groups = getattr(args, "conv_pos_groups", 16)
|
1338 |
+
args.target_glu = getattr(args, "target_glu", False)
|
1339 |
+
args.logit_temp = getattr(args, "logit_temp", 0.1)
|
1340 |
+
args.final_dim = getattr(args, "final_dim", 256)
|
1341 |
+
args.untie_final_proj = getattr(args, "untie_final_proj", True)
|
1342 |
+
args.feature_grad_mult = getattr(args, "feature_grad_mult", 0.1)
|
1343 |
+
args.use_sent_enc_layer = getattr(args, "use_sent_enc_layer", True)
|
1344 |
+
# hubert feature extractor
|
1345 |
+
args.extractor_mode = getattr(args, "extractor_mode", "default")
|
1346 |
+
args.conv_feature_layers = getattr(args, "conv_feature_layers", "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2")
|
1347 |
+
args.conv_bias = getattr(args, "conv_bias", False)
|
1348 |
+
# mask
|
1349 |
+
args.hubert_mask_length = getattr(args, "hubert_mask_length", 10)
|
1350 |
+
args.mask_prob = getattr(args, "mask_prob", 0.0)
|
1351 |
+
args.mask_selection = getattr(args, "mask_selection", "static")
|
1352 |
+
args.mask_other = getattr(args, "mask_other", 0)
|
1353 |
+
args.no_mask_overlap = getattr(args, "no_mask_overlap", False)
|
1354 |
+
args.mask_min_space = getattr(args, "mask_min_space", 1)
|
1355 |
+
# channel mask
|
1356 |
+
args.mask_channel_length = getattr(args, "mask_channel_length", 10)
|
1357 |
+
args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.0)
|
1358 |
+
args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
|
1359 |
+
args.mask_channel_other = getattr(args, "mask_channel_other", 0)
|
1360 |
+
args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False)
|
1361 |
+
args.mask_channel_min_space = getattr(args, "mask_channel_min_space", 1)
|
1362 |
+
# loss computation
|
1363 |
+
args.skip_masked = getattr(args, "skip_masked", False)
|
1364 |
+
args.skip_nomask = getattr(args, "skip_nomask", False)
|
1365 |
+
# conv Pos
|
1366 |
+
args.use_conv_pos = getattr(args, "use_conv_pos", False)
|
1367 |
+
args.use_sinc_pos = getattr(args, "use_sinc_pos", False)
|
1368 |
+
|
1369 |
+
# codebook
|
1370 |
+
args.use_codebook = getattr(args, "use_codebook", False)
|
1371 |
+
args.latent_vars = getattr(args, "latent_vars", 100)
|
1372 |
+
args.latent_groups = getattr(args, "latent_groups", 2)
|
1373 |
+
args.latent_dim = getattr(args, "latent_dim", 0)
|
1374 |
+
args.latent_temp = getattr(args, "latent_temp", (2, 0.5, 0.999995))
|
1375 |
+
args.quantizer_depth = getattr(args, "quantizer_depth", 1)
|
1376 |
+
args.quantizer_factor = getattr(args, "quantizer_factor", 3)
|
1377 |
+
args.codebook_prob = getattr(args, "codebook_prob", 0.5)
|
1378 |
+
|
1379 |
+
# Relative pos embed
|
1380 |
+
args.relative_position_embedding = getattr(args, "relative_position_embedding", False)
|
1381 |
+
args.num_buckets = getattr(args, "num_buckets", 320)
|
1382 |
+
args.max_distance = getattr(args, "max_distance", 1280)
|
1383 |
+
args.encoder_max_relative_position = getattr(args, "encoder_max_relative_position", 160)
|
1384 |
+
args.decoder_max_relative_position = getattr(args, "decoder_max_relative_position", 160)
|
1385 |
+
|
1386 |
+
@register_model_architecture("artst_transformer", "artst_transformer_base")
|
1387 |
+
def artst_transformer_base(args):
|
1388 |
+
args.use_conv_pos = getattr(args, "use_conv_pos", True)
|
1389 |
+
args.use_sinc_pos = getattr(args, "use_sinc_pos", True)
|
1390 |
+
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
|
1391 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
1392 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
|
1393 |
+
args.layer_norm_first = getattr(args, "layer_norm_first", False)
|
1394 |
+
args.relative_position_embedding = getattr(args, "relative_position_embedding", True)
|
1395 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
1396 |
+
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
|
1397 |
+
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
1398 |
+
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.05)
|
1399 |
+
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.05)
|
1400 |
+
args.mask_prob = getattr(args, "mask_prob", 0.80)
|
1401 |
+
base_architecture(args)
|
1402 |
+
|
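base_architecture and the named variants rely on the fairseq convention that every hyperparameter is written as getattr(args, name, default): a variant pins only the values it wants to change and then delegates to base_architecture, while anything already set on args (e.g. from the command line) always wins. A tiny self-contained illustration of that idiom with made-up field names:

from types import SimpleNamespace

def base_arch(args):
    args.dropout = getattr(args, "dropout", 0.1)
    args.encoder_layers = getattr(args, "encoder_layers", 12)

def large_arch(args):
    args.encoder_layers = getattr(args, "encoder_layers", 24)   # pinned before delegating
    base_arch(args)                                             # fills the remaining defaults

args = SimpleNamespace()
large_arch(args)
print(args.encoder_layers, args.dropout)    # 24 0.1

args = SimpleNamespace(dropout=0.3)         # a user-supplied value is never overridden
large_arch(args)
print(args.encoder_layers, args.dropout)    # 24 0.3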
1403 |
+
@register_model_architecture("artst_transformer", "artst_transformer_large")
|
1404 |
+
def artst_transformer_large(args):
|
1405 |
+
args.use_conv_pos = getattr(args, "use_conv_pos", True)
|
1406 |
+
args.use_sinc_pos = getattr(args, "use_sinc_pos", True)
|
1407 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
1408 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True)
|
1409 |
+
args.layer_norm_first = getattr(args, "layer_norm_first", True)
|
1410 |
+
args.relative_position_embedding = getattr(args, "relative_position_embedding", True)
|
1411 |
+
args.dropout = getattr(args, "dropout", 0.0)
|
1412 |
+
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
|
1413 |
+
args.attention_dropout = getattr(args, "attention_dropout", 0.0)
|
1414 |
+
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
|
1415 |
+
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
|
1416 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
1417 |
+
args.encoder_layers = getattr(args, "encoder_layers", 24)
|
1418 |
+
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
1419 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
|
1420 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
1421 |
+
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
|
1422 |
+
args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0)
|
1423 |
+
args.extractor_mode = getattr(args, "extractor_mode", "layer_norm")
|
1424 |
+
args.final_dim = getattr(args, "final_dim", 768)
|
1425 |
+
args.mask_prob = getattr(args, "mask_prob", 0.80)
|
1426 |
+
base_architecture(args)
|
1427 |
+
|
1428 |
+
@register_model_architecture("artst_transformer", "artst_transformer_base_asr")
|
1429 |
+
def artst_transformer_base_asr(args):
|
1430 |
+
args.use_conv_pos = getattr(args, "use_conv_pos", True)
|
1431 |
+
args.use_sinc_pos = getattr(args, "use_sinc_pos", True)
|
1432 |
+
args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
|
1433 |
+
args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
|
1434 |
+
args.layer_norm_first = getattr(args, "layer_norm_first", False)
|
1435 |
+
args.relative_position_embedding = getattr(args, "relative_position_embedding", True)
|
1436 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
1437 |
+
args.activation_dropout = getattr(args, "activation_dropout", 0.1)
|
1438 |
+
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
1439 |
+
args.feature_grad_mult = getattr(args, "feature_grad_mult", 0.0)
|
1440 |
+
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.1)
|
1441 |
+
args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.1)
|
1442 |
+
args.mask_prob = getattr(args, "mask_prob", 0.75)
|
1443 |
+
args.mask_selection = getattr(args, "mask_selection", "static")
|
1444 |
+
args.mask_channel_length = getattr(args, "mask_channel_length", 64)
|
1445 |
+
args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5)
|
1446 |
+
args.mask_channel_selection = getattr(args, "mask_channel_selection", "static")
|
1447 |
+
args.max_text_positions = getattr(args, "max_text_positions", 600)
|
1448 |
+
base_architecture(args)
|
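For orientation, this model file is what the Space's app.py ultimately drives for Arabic TTS. A hedged sketch of how text-to-speech inference with generate_speech might look; the checkpoint path, token ids, speaker embedding, and vocoder are placeholders rather than anything defined in this repository, and only fairseq's checkpoint_utils.load_model_ensemble_and_task plus the generate_speech signature above are assumed:

import torch
from fairseq import checkpoint_utils

# Hypothetical checkpoint path; real use would also encode the text with the task's dictionary.
models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(["/path/to/artst_tts_checkpoint.pt"])
model = models[0].eval()

src_tokens = torch.randint(4, 100, (1, 30))   # placeholder text token ids (1 x T_text)
spkembs = torch.randn(1, 512)                 # placeholder speaker x-vector

with torch.no_grad():
    mel, stop_probs, attn = model.generate_speech(src_tokens=src_tokens, spkembs=spkembs)

# mel is an (L, odim) spectrogram; a neural vocoder would turn it into a waveform.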
artst/models/modules/__init__.py
ADDED
File without changes
|
artst/models/modules/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (150 Bytes)
|
|
artst/models/modules/__pycache__/decoder.cpython-38.pyc
ADDED
Binary file (8.71 kB)
|
|
artst/models/modules/__pycache__/encoder.cpython-38.pyc
ADDED
Binary file (8.84 kB)
|
|
artst/models/modules/__pycache__/multihead_attention.cpython-38.pyc
ADDED
Binary file (10.6 kB)
|
|
artst/models/modules/__pycache__/speaker_decoder_postnet.cpython-38.pyc
ADDED
Binary file (6.16 kB)
|
|
artst/models/modules/__pycache__/speech_decoder_postnet.cpython-38.pyc
ADDED
Binary file (2.05 kB)
|
|
artst/models/modules/__pycache__/speech_decoder_prenet.cpython-38.pyc
ADDED
Binary file (3.54 kB)
|
|
artst/models/modules/__pycache__/speech_encoder_postnet.cpython-38.pyc
ADDED
Binary file (4.04 kB)
|
|
artst/models/modules/__pycache__/speech_encoder_prenet.cpython-38.pyc
ADDED
Binary file (10.2 kB)
|
|