litagin committed
Commit 70c3683
1 parent: d6a59a3
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitignore +44 -0
  2. app.py +253 -0
  3. bert/bert_models.json +14 -0
  4. bert/chinese-roberta-wwm-ext-large/.gitattributes +9 -0
  5. bert/chinese-roberta-wwm-ext-large/README.md +57 -0
  6. bert/chinese-roberta-wwm-ext-large/added_tokens.json +1 -0
  7. bert/chinese-roberta-wwm-ext-large/config.json +28 -0
  8. bert/chinese-roberta-wwm-ext-large/pytorch_model.bin +3 -0
  9. bert/chinese-roberta-wwm-ext-large/special_tokens_map.json +1 -0
  10. bert/chinese-roberta-wwm-ext-large/tokenizer.json +0 -0
  11. bert/chinese-roberta-wwm-ext-large/tokenizer_config.json +1 -0
  12. bert/chinese-roberta-wwm-ext-large/vocab.txt +0 -0
  13. bert/deberta-v2-large-japanese-char-wwm/.gitattributes +34 -0
  14. bert/deberta-v2-large-japanese-char-wwm/README.md +89 -0
  15. bert/deberta-v2-large-japanese-char-wwm/config.json +37 -0
  16. bert/deberta-v2-large-japanese-char-wwm/pytorch_model.bin +3 -0
  17. bert/deberta-v2-large-japanese-char-wwm/special_tokens_map.json +7 -0
  18. bert/deberta-v2-large-japanese-char-wwm/tokenizer_config.json +19 -0
  19. bert/deberta-v2-large-japanese-char-wwm/vocab.txt +0 -0
  20. bert/deberta-v3-large/.gitattributes +27 -0
  21. bert/deberta-v3-large/README.md +93 -0
  22. bert/deberta-v3-large/config.json +22 -0
  23. bert/deberta-v3-large/generator_config.json +22 -0
  24. bert/deberta-v3-large/pytorch_model.bin +3 -0
  25. bert/deberta-v3-large/pytorch_model.bin.bin +3 -0
  26. bert/deberta-v3-large/spm.model +3 -0
  27. bert/deberta-v3-large/tokenizer_config.json +4 -0
  28. chupa_examples.txt +0 -0
  29. model_assets/chupa_1/chupa_1spk_e1000_s194312.safetensors +3 -0
  30. model_assets/chupa_1/config.json +87 -0
  31. model_assets/chupa_1/style_vectors.npy +3 -0
  32. requirements.txt +23 -0
  33. style_bert_vits2/.editorconfig +15 -0
  34. style_bert_vits2/__init__.py +0 -0
  35. style_bert_vits2/constants.py +48 -0
  36. style_bert_vits2/logging.py +15 -0
  37. style_bert_vits2/models/__init__.py +0 -0
  38. style_bert_vits2/models/attentions.py +491 -0
  39. style_bert_vits2/models/commons.py +223 -0
  40. style_bert_vits2/models/hyper_parameters.py +129 -0
  41. style_bert_vits2/models/infer.py +308 -0
  42. style_bert_vits2/models/models.py +1102 -0
  43. style_bert_vits2/models/models_jp_extra.py +1157 -0
  44. style_bert_vits2/models/modules.py +642 -0
  45. style_bert_vits2/models/monotonic_alignment.py +89 -0
  46. style_bert_vits2/models/transforms.py +215 -0
  47. style_bert_vits2/models/utils/__init__.py +264 -0
  48. style_bert_vits2/models/utils/checkpoints.py +202 -0
  49. style_bert_vits2/models/utils/safetensors.py +91 -0
  50. style_bert_vits2/nlp/__init__.py +120 -0
.gitignore ADDED
@@ -0,0 +1,44 @@
+ __pycache__/
+ venv/
+ .venv/
+ dist/
+ .coverage
+ .ipynb_checkpoints/
+ .ruff_cache/
+
+ /*.yml
+ !/default_config.yml
+ # /bert/*/*.bin
+ # /bert/*/*.h5
+ # /bert/*/*.model
+ # /bert/*/*.safetensors
+ # /bert/*/*.msgpack
+
+ /configs/paths.yml
+
+ /pretrained/*.safetensors
+ /pretrained/*.pth
+
+ /pretrained_jp_extra/*.safetensors
+ /pretrained_jp_extra/*.pth
+
+ /slm/*/*.bin
+
+ /scripts/test/
+ /scripts/lib/
+ /scripts/Style-Bert-VITS2/
+ /scripts/sbv2/
+ *.zip
+ *.csv
+ *.bak
+ /mos_results/
+
+ safetensors.ipynb
+ *.wav
+ /static/
+
+ # pyopenjtalk's dictionary
+ *.dic
+
+ playground.ipynb
+ playgrounds/
app.py ADDED
@@ -0,0 +1,253 @@
+ import datetime
+ from pathlib import Path
+ import gradio as gr
+ import random
+ from style_bert_vits2.constants import (
+     DEFAULT_LENGTH,
+     DEFAULT_LINE_SPLIT,
+     DEFAULT_NOISE,
+     DEFAULT_NOISEW,
+     DEFAULT_SPLIT_INTERVAL,
+ )
+ from style_bert_vits2.logging import logger
+ from style_bert_vits2.models.infer import InvalidToneError
+ from style_bert_vits2.nlp.japanese import pyopenjtalk_worker as pyopenjtalk
+ from style_bert_vits2.tts_model import TTSModelHolder
+
+
+ pyopenjtalk.initialize_worker()
+
+ example_file = "chupa_examples.txt"
+
+ initial_text = (
+     "ちゅぱ、ちゅるる、ぢゅ、んく、れーれゅれろれろれろ、じゅぽぽぽぽぽ……ちゅううう!"
+ )
+
+ with open(example_file, "r", encoding="utf-8") as f:
+     examples = f.read().splitlines()
+
+
+ def get_random_text() -> str:
+     return random.choice(examples)
+
+
+ initial_md = """
+ # チュパ音合成デモ
+
+ 2024-07-07: initial ver
+ """
+
+
+ def make_interactive():
+     return gr.update(interactive=True, value="音声合成")
+
+
+ def make_non_interactive():
+     return gr.update(interactive=False, value="音声合成(モデルをロードしてください)")
+
+
+ def gr_util(item):
+     if item == "プリセットから選ぶ":
+         return (gr.update(visible=True), gr.Audio(visible=False, value=None))
+     else:
+         return (gr.update(visible=False), gr.update(visible=True))
+
+
+ def create_inference_app(model_holder: TTSModelHolder) -> gr.Blocks:
+     def tts_fn(
+         model_name,
+         model_path,
+         text,
+         language,
+         sdp_ratio,
+         noise_scale,
+         noise_scale_w,
+         length_scale,
+         line_split,
+         split_interval,
+         speaker,
+     ):
+         model_holder.get_model(model_name, model_path)
+         assert model_holder.current_model is not None
+
+         speaker_id = model_holder.current_model.spk2id[speaker]
+
+         start_time = datetime.datetime.now()
+
+         try:
+             sr, audio = model_holder.current_model.infer(
+                 text=text,
+                 language=language,
+                 sdp_ratio=sdp_ratio,
+                 noise=noise_scale,
+                 noise_w=noise_scale_w,
+                 length=length_scale,
+                 line_split=line_split,
+                 split_interval=split_interval,
+                 speaker_id=speaker_id,
+             )
+         except InvalidToneError as e:
+             logger.error(f"Tone error: {e}")
+             return f"Error: アクセント指定が不正です:\n{e}", None
+         except ValueError as e:
+             logger.error(f"Value error: {e}")
+             return f"Error: {e}", None
+
+         end_time = datetime.datetime.now()
+         duration = (end_time - start_time).total_seconds()
+
+         message = f"Success, time: {duration} seconds."
+         return message, (sr, audio)
+
+     def get_model_files(model_name: str):
+         return [str(f) for f in model_holder.model_files_dict[model_name]]
+
+     model_names = model_holder.model_names
+     if len(model_names) == 0:
+         logger.error(
+             f"モデルが見つかりませんでした。{model_holder.root_dir}にモデルを置いてください。"
+         )
+         with gr.Blocks() as app:
+             gr.Markdown(
+                 f"Error: モデルが見つかりませんでした。{model_holder.root_dir}にモデルを置いてください。"
+             )
+         return app
+
+     initial_pth_files = get_model_files(model_names[0])
+     model = model_holder.get_model(model_names[0], initial_pth_files[0])
+     speakers = list(model.spk2id.keys())
+
+     with gr.Blocks(theme="ParityError/Anime") as app:
+         gr.Markdown(initial_md)
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     with gr.Column(scale=3):
+                         model_name = gr.Dropdown(
+                             label="モデル一覧",
+                             choices=model_names,
+                             value=model_names[0],
+                         )
+                         model_path = gr.Dropdown(
+                             label="モデルファイル",
+                             choices=initial_pth_files,
+                             value=initial_pth_files[0],
+                         )
+                     refresh_button = gr.Button("更新", scale=1, visible=False)
+                     load_button = gr.Button("ロード", scale=1, variant="primary")
+                 with gr.Row():
+                     text_input = gr.TextArea(
+                         label="テキスト", value=initial_text, scale=3
+                     )
+                     random_button = gr.Button("例から選ぶ 🎲", scale=1)
+                 random_button.click(get_random_text, outputs=[text_input])
+                 with gr.Row():
+                     length_scale = gr.Slider(
+                         minimum=0.1,
+                         maximum=2,
+                         value=DEFAULT_LENGTH,
+                         step=0.1,
+                         label="生成音声の長さ(Length)",
+                     )
+                     sdp_ratio = gr.Slider(
+                         minimum=0,
+                         maximum=1,
+                         value=1,
+                         step=0.1,
+                         label="SDP Ratio",
+                     )
+                 line_split = gr.Checkbox(
+                     label="改行で分けて生成(分けたほうが感情が乗ります)",
+                     value=DEFAULT_LINE_SPLIT,
+                     visible=False,
+                 )
+                 split_interval = gr.Slider(
+                     minimum=0.0,
+                     maximum=2,
+                     value=DEFAULT_SPLIT_INTERVAL,
+                     step=0.1,
+                     label="改行ごとに挟む無音の長さ(秒)",
+                 )
+                 line_split.change(
+                     lambda x: (gr.Slider(visible=x)),
+                     inputs=[line_split],
+                     outputs=[split_interval],
+                 )
+                 language = gr.Dropdown(
+                     choices=["JP"], value="JP", label="Language", visible=False
+                 )
+                 speaker = gr.Dropdown(label="話者", choices=speakers, value=speakers[0])
+                 with gr.Accordion(label="詳細設定", open=True):
+                     noise_scale = gr.Slider(
+                         minimum=0.1,
+                         maximum=2,
+                         value=DEFAULT_NOISE,
+                         step=0.1,
+                         label="Noise",
+                     )
+                     noise_scale_w = gr.Slider(
+                         minimum=0.1,
+                         maximum=2,
+                         value=DEFAULT_NOISEW,
+                         step=0.1,
+                         label="Noise_W",
+                     )
+             with gr.Column():
+                 tts_button = gr.Button("音声合成", variant="primary")
+                 text_output = gr.Textbox(label="情報")
+                 audio_output = gr.Audio(label="結果")
+
+         tts_button.click(
+             tts_fn,
+             inputs=[
+                 model_name,
+                 model_path,
+                 text_input,
+                 language,
+                 sdp_ratio,
+                 noise_scale,
+                 noise_scale_w,
+                 length_scale,
+                 line_split,
+                 split_interval,
+                 speaker,
+             ],
+             outputs=[text_output, audio_output],
+         )
+
+         model_name.change(
+             model_holder.update_model_files_for_gradio,
+             inputs=[model_name],
+             outputs=[model_path],
+         )
+
+         model_path.change(make_non_interactive, outputs=[tts_button])
+
+         refresh_button.click(
+             model_holder.update_model_names_for_gradio,
+             outputs=[model_name, model_path, tts_button],
+         )
+         style = gr.Dropdown(label="スタイル", choices=[], visible=False)
+
+         load_button.click(
+             model_holder.get_model_for_gradio,
+             inputs=[model_name, model_path],
+             outputs=[style, tts_button, speaker],
+         )
+
+     return app
+
+
+ if __name__ == "__main__":
+     import torch
+
+     from style_bert_vits2.constants import Languages
+     from style_bert_vits2.nlp import bert_models
+
+     bert_models.load_model(Languages.JP)
+     bert_models.load_tokenizer(Languages.JP)
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model_holder = TTSModelHolder(Path("model_assets"), device)
+     app = create_inference_app(model_holder)
+     app.launch(inbrowser=True)
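Illustrative note (not part of the commit): the `__main__` block above can be adapted to run a single synthesis without launching the Gradio UI. The sketch below assumes the same `TTSModelHolder.get_model()` / `current_model.infer()` interface that `tts_fn` uses above, with the `chupa_1` model under `model_assets/`; the input text is an arbitrary example.

```python
from pathlib import Path

import torch

from style_bert_vits2.constants import DEFAULT_LENGTH, DEFAULT_NOISE, DEFAULT_NOISEW, Languages
from style_bert_vits2.nlp import bert_models
from style_bert_vits2.tts_model import TTSModelHolder

# Load the Japanese BERT model and tokenizer, as in the __main__ block of app.py.
bert_models.load_model(Languages.JP)
bert_models.load_tokenizer(Languages.JP)

device = "cuda" if torch.cuda.is_available() else "cpu"
holder = TTSModelHolder(Path("model_assets"), device)

# Pick the first model file of the first model, mirroring create_inference_app().
model_name = holder.model_names[0]
model_file = holder.model_files_dict[model_name][0]
holder.get_model(model_name, str(model_file))
assert holder.current_model is not None

# Synthesize one line with the same keyword arguments tts_fn passes to infer().
speaker = next(iter(holder.current_model.spk2id))
sr, audio = holder.current_model.infer(
    text="ちゅぱ、ちゅるる……",  # arbitrary example text
    language="JP",
    sdp_ratio=1,
    noise=DEFAULT_NOISE,
    noise_w=DEFAULT_NOISEW,
    length=DEFAULT_LENGTH,
    speaker_id=holder.current_model.spk2id[speaker],
)
print(f"Generated {len(audio) / sr:.2f} s of audio at {sr} Hz")
```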
bert/bert_models.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "deberta-v2-large-japanese-char-wwm": {
+     "repo_id": "ku-nlp/deberta-v2-large-japanese-char-wwm",
+     "files": ["pytorch_model.bin"]
+   },
+   "chinese-roberta-wwm-ext-large": {
+     "repo_id": "hfl/chinese-roberta-wwm-ext-large",
+     "files": ["pytorch_model.bin"]
+   },
+   "deberta-v3-large": {
+     "repo_id": "microsoft/deberta-v3-large",
+     "files": ["spm.model", "pytorch_model.bin"]
+   }
+ }
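Illustrative note (not part of the commit): a registry like `bert_models.json` can be consumed with the standard `huggingface_hub` API to pre-fetch the listed weights into the matching `bert/` subdirectories. A minimal sketch, assuming `huggingface_hub` is installed:

```python
import json
from pathlib import Path

from huggingface_hub import hf_hub_download

registry = json.loads(Path("bert/bert_models.json").read_text(encoding="utf-8"))

for local_name, spec in registry.items():
    for filename in spec["files"]:
        # Download e.g. pytorch_model.bin from the upstream repo into bert/<local_name>/.
        hf_hub_download(
            repo_id=spec["repo_id"],
            filename=filename,
            local_dir=Path("bert") / local_name,
        )
```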
bert/chinese-roberta-wwm-ext-large/.gitattributes ADDED
@@ -0,0 +1,9 @@
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
bert/chinese-roberta-wwm-ext-large/README.md ADDED
@@ -0,0 +1,57 @@
+ ---
+ language:
+ - zh
+ tags:
+ - bert
+ license: "apache-2.0"
+ ---
+
+ # Please use 'Bert' related functions to load this model!
+
+ ## Chinese BERT with Whole Word Masking
+ For further accelerating Chinese natural language processing, we provide **Chinese pre-trained BERT with Whole Word Masking**.
+
+ **[Pre-Training with Whole Word Masking for Chinese BERT](https://arxiv.org/abs/1906.08101)**
+ Yiming Cui, Wanxiang Che, Ting Liu, Bing Qin, Ziqing Yang, Shijin Wang, Guoping Hu
+
+ This repository is developed based on:https://github.com/google-research/bert
+
+ You may also interested in,
+ - Chinese BERT series: https://github.com/ymcui/Chinese-BERT-wwm
+ - Chinese MacBERT: https://github.com/ymcui/MacBERT
+ - Chinese ELECTRA: https://github.com/ymcui/Chinese-ELECTRA
+ - Chinese XLNet: https://github.com/ymcui/Chinese-XLNet
+ - Knowledge Distillation Toolkit - TextBrewer: https://github.com/airaria/TextBrewer
+
+ More resources by HFL: https://github.com/ymcui/HFL-Anthology
+
+ ## Citation
+ If you find the technical report or resource is useful, please cite the following technical report in your paper.
+ - Primary: https://arxiv.org/abs/2004.13922
+ ```
+ @inproceedings{cui-etal-2020-revisiting,
+ title = "Revisiting Pre-Trained Models for {C}hinese Natural Language Processing",
+ author = "Cui, Yiming and
+ Che, Wanxiang and
+ Liu, Ting and
+ Qin, Bing and
+ Wang, Shijin and
+ Hu, Guoping",
+ booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: Findings",
+ month = nov,
+ year = "2020",
+ address = "Online",
+ publisher = "Association for Computational Linguistics",
+ url = "https://www.aclweb.org/anthology/2020.findings-emnlp.58",
+ pages = "657--668",
+ }
+ ```
+ - Secondary: https://arxiv.org/abs/1906.08101
+ ```
+ @article{chinese-bert-wwm,
+ title={Pre-Training with Whole Word Masking for Chinese BERT},
+ author={Cui, Yiming and Che, Wanxiang and Liu, Ting and Qin, Bing and Yang, Ziqing and Wang, Shijin and Hu, Guoping},
+ journal={arXiv preprint arXiv:1906.08101},
+ year={2019}
+ }
+ ```
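Illustrative note (not part of the committed file): since the README above asks that 'Bert'-related classes be used, loading this checkpoint with `transformers` would look roughly like the following sketch, assuming the standard BERT classes.

```python
from transformers import BertForMaskedLM, BertTokenizerFast

# Load from the local directory added in this commit
# (or from the upstream "hfl/chinese-roberta-wwm-ext-large" repo).
tokenizer = BertTokenizerFast.from_pretrained("bert/chinese-roberta-wwm-ext-large")
model = BertForMaskedLM.from_pretrained("bert/chinese-roberta-wwm-ext-large")
```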
bert/chinese-roberta-wwm-ext-large/added_tokens.json ADDED
@@ -0,0 +1 @@
+ {}
bert/chinese-roberta-wwm-ext-large/config.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "architectures": [
+     "BertForMaskedLM"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "bos_token_id": 0,
+   "directionality": "bidi",
+   "eos_token_id": 2,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 24,
+   "output_past": true,
+   "pad_token_id": 0,
+   "pooler_fc_size": 768,
+   "pooler_num_attention_heads": 12,
+   "pooler_num_fc_layers": 3,
+   "pooler_size_per_head": 128,
+   "pooler_type": "first_token_transform",
+   "type_vocab_size": 2,
+   "vocab_size": 21128
+ }
bert/chinese-roberta-wwm-ext-large/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4ac62d49144d770c5ca9a5d1d3039c4995665a080febe63198189857c6bd11cd
+ size 1306484351
bert/chinese-roberta-wwm-ext-large/special_tokens_map.json ADDED
@@ -0,0 +1 @@
+ {"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
bert/chinese-roberta-wwm-ext-large/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
bert/chinese-roberta-wwm-ext-large/tokenizer_config.json ADDED
@@ -0,0 +1 @@
+ {"init_inputs": []}
bert/chinese-roberta-wwm-ext-large/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
bert/deberta-v2-large-japanese-char-wwm/.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
bert/deberta-v2-large-japanese-char-wwm/README.md ADDED
@@ -0,0 +1,89 @@
+ ---
+ language: ja
+ license: cc-by-sa-4.0
+ library_name: transformers
+ tags:
+ - deberta
+ - deberta-v2
+ - fill-mask
+ - character
+ - wwm
+ datasets:
+ - wikipedia
+ - cc100
+ - oscar
+ metrics:
+ - accuracy
+ mask_token: "[MASK]"
+ widget:
+ - text: "京都大学で自然言語処理を[MASK][MASK]する。"
+ ---
+
+ # Model Card for Japanese character-level DeBERTa V2 large
+
+ ## Model description
+
+ This is a Japanese DeBERTa V2 large model pre-trained on Japanese Wikipedia, the Japanese portion of CC-100, and the Japanese portion of OSCAR.
+ This model is trained with character-level tokenization and whole word masking.
+
+ ## How to use
+
+ You can use this model for masked language modeling as follows:
+
+ ```python
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+ tokenizer = AutoTokenizer.from_pretrained('ku-nlp/deberta-v2-large-japanese-char-wwm')
+ model = AutoModelForMaskedLM.from_pretrained('ku-nlp/deberta-v2-large-japanese-char-wwm')
+
+ sentence = '京都大学で自然言語処理を[MASK][MASK]する。'
+ encoding = tokenizer(sentence, return_tensors='pt')
+ ...
+ ```
+
+ You can also fine-tune this model on downstream tasks.
+
+ ## Tokenization
+
+ There is no need to tokenize texts in advance, and you can give raw texts to the tokenizer.
+ The texts are tokenized into character-level tokens by [sentencepiece](https://github.com/google/sentencepiece).
+
+ ## Training data
+
+ We used the following corpora for pre-training:
+
+ - Japanese Wikipedia (as of 20221020, 3.2GB, 27M sentences, 1.3M documents)
+ - Japanese portion of CC-100 (85GB, 619M sentences, 66M documents)
+ - Japanese portion of OSCAR (54GB, 326M sentences, 25M documents)
+
+ Note that we filtered out documents annotated with "header", "footer", or "noisy" tags in OSCAR.
+ Also note that Japanese Wikipedia was duplicated 10 times to make the total size of the corpus comparable to that of CC-100 and OSCAR. As a result, the total size of the training data is 171GB.
+
+ ## Training procedure
+
+ We first segmented texts in the corpora into words using [Juman++ 2.0.0-rc3](https://github.com/ku-nlp/jumanpp/releases/tag/v2.0.0-rc3) for whole word masking.
+ Then, we built a sentencepiece model with 22,012 tokens including all characters that appear in the training corpus.
+
+ We tokenized raw corpora into character-level subwords using the sentencepiece model and trained the Japanese DeBERTa model using [transformers](https://github.com/huggingface/transformers) library.
+ The training took 26 days using 16 NVIDIA A100-SXM4-40GB GPUs.
+
+ The following hyperparameters were used during pre-training:
+
+ - learning_rate: 1e-4
+ - per_device_train_batch_size: 26
+ - distributed_type: multi-GPU
+ - num_devices: 16
+ - gradient_accumulation_steps: 8
+ - total_train_batch_size: 3,328
+ - max_seq_length: 512
+ - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-06
+ - lr_scheduler_type: linear schedule with warmup (lr = 0 at 300k steps)
+ - training_steps: 260,000
+ - warmup_steps: 10,000
+
+ The accuracy of the trained model on the masked language modeling task was 0.795.
+ The evaluation set consists of 5,000 randomly sampled documents from each of the training corpora.
+
+ ## Acknowledgments
+
+ This work was supported by Joint Usage/Research Center for Interdisciplinary Large-scale Information Infrastructures (JHPCN) through General Collaboration Project no. jh221004, "Developing a Platform for Constructing and Sharing of Large-Scale Japanese Language Models".
+ For training models, we used the mdx: a platform for the data-driven future.
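Illustrative note (not part of the committed file): the usage snippet in the README above stops after encoding. One way to finish the masked-token prediction with the standard `transformers` API would be, as a rough sketch:

```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ku-nlp/deberta-v2-large-japanese-char-wwm")
model = AutoModelForMaskedLM.from_pretrained("ku-nlp/deberta-v2-large-japanese-char-wwm")

sentence = "京都大学で自然言語処理を[MASK][MASK]する。"
encoding = tokenizer(sentence, return_tensors="pt")

with torch.no_grad():
    logits = model(**encoding).logits

# Replace each [MASK] position with its highest-scoring token.
mask_positions = (encoding.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
predicted_ids = logits[0, mask_positions].argmax(dim=-1)
print(tokenizer.decode(predicted_ids))
```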
bert/deberta-v2-large-japanese-char-wwm/config.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "architectures": [
+     "DebertaV2ForMaskedLM"
+   ],
+   "attention_head_size": 64,
+   "attention_probs_dropout_prob": 0.1,
+   "conv_act": "gelu",
+   "conv_kernel_size": 3,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "layer_norm_eps": 1e-07,
+   "max_position_embeddings": 512,
+   "max_relative_positions": -1,
+   "model_type": "deberta-v2",
+   "norm_rel_ebd": "layer_norm",
+   "num_attention_heads": 16,
+   "num_hidden_layers": 24,
+   "pad_token_id": 0,
+   "pooler_dropout": 0,
+   "pooler_hidden_act": "gelu",
+   "pooler_hidden_size": 1024,
+   "pos_att_type": [
+     "p2c",
+     "c2p"
+   ],
+   "position_biased_input": false,
+   "position_buckets": 256,
+   "relative_attention": true,
+   "share_att_key": true,
+   "torch_dtype": "float16",
+   "transformers_version": "4.25.1",
+   "type_vocab_size": 0,
+   "vocab_size": 22012
+ }
bert/deberta-v2-large-japanese-char-wwm/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bf0dab8ad87bd7c22e85ec71e04f2240804fda6d33196157d6b5923af6ea1201
+ size 1318456639
bert/deberta-v2-large-japanese-char-wwm/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
bert/deberta-v2-large-japanese-char-wwm/tokenizer_config.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "cls_token": "[CLS]",
+   "do_lower_case": false,
+   "do_subword_tokenize": true,
+   "do_word_tokenize": true,
+   "jumanpp_kwargs": null,
+   "mask_token": "[MASK]",
+   "mecab_kwargs": null,
+   "model_max_length": 1000000000000000019884624838656,
+   "never_split": null,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "special_tokens_map_file": null,
+   "subword_tokenizer_type": "character",
+   "sudachi_kwargs": null,
+   "tokenizer_class": "BertJapaneseTokenizer",
+   "unk_token": "[UNK]",
+   "word_tokenizer_type": "basic"
+ }
bert/deberta-v2-large-japanese-char-wwm/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
bert/deberta-v3-large/.gitattributes ADDED
@@ -0,0 +1,27 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
bert/deberta-v3-large/README.md ADDED
@@ -0,0 +1,93 @@
+ ---
+ language: en
+ tags:
+ - deberta
+ - deberta-v3
+ - fill-mask
+ thumbnail: https://huggingface.co/front/thumbnails/microsoft.png
+ license: mit
+ ---
+
+ ## DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing
+
+ [DeBERTa](https://arxiv.org/abs/2006.03654) improves the BERT and RoBERTa models using disentangled attention and enhanced mask decoder. With those two improvements, DeBERTa out perform RoBERTa on a majority of NLU tasks with 80GB training data.
+
+ In [DeBERTa V3](https://arxiv.org/abs/2111.09543), we further improved the efficiency of DeBERTa using ELECTRA-Style pre-training with Gradient Disentangled Embedding Sharing. Compared to DeBERTa, our V3 version significantly improves the model performance on downstream tasks. You can find more technique details about the new model from our [paper](https://arxiv.org/abs/2111.09543).
+
+ Please check the [official repository](https://github.com/microsoft/DeBERTa) for more implementation details and updates.
+
+ The DeBERTa V3 large model comes with 24 layers and a hidden size of 1024. It has 304M backbone parameters with a vocabulary containing 128K tokens which introduces 131M parameters in the Embedding layer. This model was trained using the 160GB data as DeBERTa V2.
+
+
+ #### Fine-tuning on NLU tasks
+
+ We present the dev results on SQuAD 2.0 and MNLI tasks.
+
+ | Model |Vocabulary(K)|Backbone #Params(M)| SQuAD 2.0(F1/EM) | MNLI-m/mm(ACC)|
+ |-------------------|----------|-------------------|-----------|----------|
+ | RoBERTa-large |50 |304 | 89.4/86.5 | 90.2 |
+ | XLNet-large |32 |- | 90.6/87.9 | 90.8 |
+ | DeBERTa-large |50 |- | 90.7/88.0 | 91.3 |
+ | **DeBERTa-v3-large**|128|304 | **91.5/89.0**| **91.8/91.9**|
+
+
+ #### Fine-tuning with HF transformers
+
+ ```bash
+ #!/bin/bash
+
+ cd transformers/examples/pytorch/text-classification/
+
+ pip install datasets
+ export TASK_NAME=mnli
+
+ output_dir="ds_results"
+
+ num_gpus=8
+
+ batch_size=8
+
+ python -m torch.distributed.launch --nproc_per_node=${num_gpus} \
+   run_glue.py \
+   --model_name_or_path microsoft/deberta-v3-large \
+   --task_name $TASK_NAME \
+   --do_train \
+   --do_eval \
+   --evaluation_strategy steps \
+   --max_seq_length 256 \
+   --warmup_steps 50 \
+   --per_device_train_batch_size ${batch_size} \
+   --learning_rate 6e-6 \
+   --num_train_epochs 2 \
+   --output_dir $output_dir \
+   --overwrite_output_dir \
+   --logging_steps 1000 \
+   --logging_dir $output_dir
+
+ ```
+
+ ### Citation
+
+ If you find DeBERTa useful for your work, please cite the following papers:
+
+ ``` latex
+ @misc{he2021debertav3,
+ title={DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing},
+ author={Pengcheng He and Jianfeng Gao and Weizhu Chen},
+ year={2021},
+ eprint={2111.09543},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL}
+ }
+ ```
+
+ ``` latex
+ @inproceedings{
+ he2021deberta,
+ title={DEBERTA: DECODING-ENHANCED BERT WITH DISENTANGLED ATTENTION},
+ author={Pengcheng He and Xiaodong Liu and Jianfeng Gao and Weizhu Chen},
+ booktitle={International Conference on Learning Representations},
+ year={2021},
+ url={https://openreview.net/forum?id=XPZIaotutsD}
+ }
+ ```
bert/deberta-v3-large/config.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "model_type": "deberta-v2",
+   "attention_probs_dropout_prob": 0.1,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "max_position_embeddings": 512,
+   "relative_attention": true,
+   "position_buckets": 256,
+   "norm_rel_ebd": "layer_norm",
+   "share_att_key": true,
+   "pos_att_type": "p2c|c2p",
+   "layer_norm_eps": 1e-7,
+   "max_relative_positions": -1,
+   "position_biased_input": false,
+   "num_attention_heads": 16,
+   "num_hidden_layers": 24,
+   "type_vocab_size": 0,
+   "vocab_size": 128100
+ }
bert/deberta-v3-large/generator_config.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "model_type": "deberta-v2",
+   "attention_probs_dropout_prob": 0.1,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "max_position_embeddings": 512,
+   "relative_attention": true,
+   "position_buckets": 256,
+   "norm_rel_ebd": "layer_norm",
+   "share_att_key": true,
+   "pos_att_type": "p2c|c2p",
+   "layer_norm_eps": 1e-7,
+   "max_relative_positions": -1,
+   "position_biased_input": false,
+   "num_attention_heads": 16,
+   "num_hidden_layers": 12,
+   "type_vocab_size": 0,
+   "vocab_size": 128100
+ }
bert/deberta-v3-large/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dd5b5d93e2db101aaf281df0ea1216c07ad73620ff59c5b42dccac4bf2eef5b5
+ size 873673253
bert/deberta-v3-large/pytorch_model.bin.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dd5b5d93e2db101aaf281df0ea1216c07ad73620ff59c5b42dccac4bf2eef5b5
+ size 873673253
bert/deberta-v3-large/spm.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd
+ size 2464616
bert/deberta-v3-large/tokenizer_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "do_lower_case": false,
+   "vocab_type": "spm"
+ }
chupa_examples.txt ADDED
The diff for this file is too large to render. See raw diff
 
model_assets/chupa_1/chupa_1spk_e1000_s194312.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8af08fae399f64bbc506a4accaf6c56b0d294def6435235dbe60755728784d8c
+ size 251150980
model_assets/chupa_1/config.json ADDED
@@ -0,0 +1,87 @@
+ {
+   "model_name": "chupa_1spk",
+   "train": {
+     "log_interval": 50,
+     "eval_interval": 1000,
+     "seed": 42,
+     "epochs": 1000,
+     "learning_rate": 0.0001,
+     "betas": [0.8, 0.99],
+     "eps": 1e-9,
+     "batch_size": 2,
+     "bf16_run": false,
+     "fp16_run": false,
+     "lr_decay": 0.99996,
+     "segment_size": 16384,
+     "init_lr_ratio": 1,
+     "warmup_epochs": 0,
+     "c_mel": 45,
+     "c_kl": 1.0,
+     "c_commit": 100,
+     "skip_optimizer": false,
+     "freeze_ZH_bert": false,
+     "freeze_JP_bert": false,
+     "freeze_EN_bert": false,
+     "freeze_emo": false,
+     "freeze_style": false,
+     "freeze_decoder": false
+   },
+   "data": {
+     "use_jp_extra": true,
+     "training_files": "Data/chupa_1/train.list",
+     "validation_files": "Data/chupa_1/val.list",
+     "max_wav_value": 32768.0,
+     "sampling_rate": 44100,
+     "filter_length": 2048,
+     "hop_length": 512,
+     "win_length": 2048,
+     "n_mel_channels": 128,
+     "mel_fmin": 0.0,
+     "mel_fmax": null,
+     "add_blank": true,
+     "n_speakers": 1,
+     "spk2id": {
+       "1": 0
+     },
+     "cleaned_text": true,
+     "num_styles": 1,
+     "style2id": {
+       "Neutral": 0
+     }
+   },
+   "model": {
+     "use_spk_conditioned_encoder": true,
+     "use_noise_scaled_mas": true,
+     "use_mel_posterior_encoder": false,
+     "use_duration_discriminator": false,
+     "use_wavlm_discriminator": true,
+     "inter_channels": 192,
+     "hidden_channels": 192,
+     "filter_channels": 768,
+     "n_heads": 2,
+     "n_layers": 6,
+     "kernel_size": 3,
+     "p_dropout": 0.1,
+     "resblock": "1",
+     "resblock_kernel_sizes": [3, 7, 11],
+     "resblock_dilation_sizes": [
+       [1, 3, 5],
+       [1, 3, 5],
+       [1, 3, 5]
+     ],
+     "upsample_rates": [8, 8, 2, 2, 2],
+     "upsample_initial_channel": 512,
+     "upsample_kernel_sizes": [16, 16, 8, 2, 2],
+     "n_layers_q": 3,
+     "use_spectral_norm": false,
+     "gin_channels": 512,
+     "slm": {
+       "model": "./slm/wavlm-base-plus",
+       "sr": 16000,
+       "hidden": 768,
+       "nlayers": 13,
+       "initial_channel": 64
+     }
+   },
+   "version": "2.6.0-JP-Extra"
+ }
model_assets/chupa_1/style_vectors.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9fd42ba186b887c87b57fa66f5781f3fdf4382504d971d5338288d50b8b40461
+ size 1152
requirements.txt ADDED
@@ -0,0 +1,23 @@
+ cmudict
+ cn2an
+ # faster-whisper==0.10.1
+ g2p_en
+ GPUtil
+ gradio
+ jieba
+ # librosa==0.9.2
+ loguru
+ num2words
+ numpy<2
+ # protobuf==4.25
+ psutil
+ # punctuators
+ pyannote.audio>=3.1.0
+ # pyloudnorm
+ pyopenjtalk-dict
+ pypinyin
+ pyworld-prebuilt
+ # stable_ts
+ # tensorboard
+ torch
+ transformers
style_bert_vits2/.editorconfig ADDED
@@ -0,0 +1,15 @@
+ root = true
+
+ [*]
+ charset = utf-8
+ end_of_line = lf
+ insert_final_newline = true
+ indent_size = 4
+ indent_style = space
+ trim_trailing_whitespace = true
+
+ [*.md]
+ trim_trailing_whitespace = false
+
+ [*.yml]
+ indent_size = 2
style_bert_vits2/__init__.py ADDED
File without changes
style_bert_vits2/constants.py ADDED
@@ -0,0 +1,48 @@
+ from pathlib import Path
+
+ from style_bert_vits2.utils.strenum import StrEnum
+
+
+ # Style-Bert-VITS2 のバージョン
+ VERSION = "2.6.0"
+
+ # Style-Bert-VITS2 のベースディレクトリ
+ BASE_DIR = Path(__file__).parent.parent
+
+
+ # 利用可能な言語
+ ## JP-Extra モデル利用時は JP 以外の言語の音声合成はできない
+ class Languages(StrEnum):
+     JP = "JP"
+     EN = "EN"
+     ZH = "ZH"
+
+
+ # 言語ごとのデフォルトの BERT トークナイザーのパス
+ DEFAULT_BERT_TOKENIZER_PATHS = {
+     Languages.JP: BASE_DIR / "bert" / "deberta-v2-large-japanese-char-wwm",
+     Languages.EN: BASE_DIR / "bert" / "deberta-v3-large",
+     Languages.ZH: BASE_DIR / "bert" / "chinese-roberta-wwm-ext-large",
+ }
+
+ # デフォルトのユーザー辞書ディレクトリ
+ ## style_bert_vits2.nlp.japanese.user_dict モジュールのデフォルト値として利用される
+ ## ライブラリとしての利用などで外部のユーザー辞書を指定したい場合は、user_dict 以下の各関数の実行時、引数に辞書データファイルのパスを指定する
+ DEFAULT_USER_DICT_DIR = BASE_DIR / "dict_data"
+
+ # デフォルトの推論パラメータ
+ DEFAULT_STYLE = "Neutral"
+ DEFAULT_STYLE_WEIGHT = 1.0
+ DEFAULT_SDP_RATIO = 0.2
+ DEFAULT_NOISE = 0.6
+ DEFAULT_NOISEW = 0.8
+ DEFAULT_LENGTH = 1.0
+ DEFAULT_LINE_SPLIT = True
+ DEFAULT_SPLIT_INTERVAL = 0.5
+ DEFAULT_ASSIST_TEXT_WEIGHT = 0.7
+ DEFAULT_ASSIST_TEXT_WEIGHT = 1.0
+
+ # Gradio のテーマ
+ ## Built-in theme: "default", "base", "monochrome", "soft", "glass"
+ ## See https://huggingface.co/spaces/gradio/theme-gallery for more themes
+ GRADIO_THEME = "NoCrypt/miku"
style_bert_vits2/logging.py ADDED
@@ -0,0 +1,15 @@
+ from loguru import logger
+
+ from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT
+
+
+ # Remove all default handlers
+ logger.remove()
+
+ # Add a new handler
+ logger.add(
+     SAFE_STDOUT,
+     format="<g>{time:MM-DD HH:mm:ss}</g> |<lvl>{level:^8}</lvl>| {file}:{line} | {message}",
+     backtrace=True,
+     diagnose=True,
+ )
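Illustrative note (not part of the committed file): the module above configures a single loguru sink, so the rest of the package simply imports and reuses that logger rather than calling `logging.getLogger`. A minimal usage sketch:

```python
from style_bert_vits2.logging import logger

# Messages go through the sink configured above, e.g. roughly
# "07-07 12:34:56 |  INFO  | my_module.py:4 | Model loaded"
logger.info("Model loaded")
logger.warning("Falling back to CPU")
```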
style_bert_vits2/models/__init__.py ADDED
File without changes
style_bert_vits2/models/attentions.py ADDED
@@ -0,0 +1,491 @@
1
+ import math
2
+ from typing import Any, Optional
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from style_bert_vits2.models import commons
9
+
10
+
11
+ class LayerNorm(nn.Module):
12
+ def __init__(self, channels: int, eps: float = 1e-5) -> None:
13
+ super().__init__()
14
+ self.channels = channels
15
+ self.eps = eps
16
+
17
+ self.gamma = nn.Parameter(torch.ones(channels))
18
+ self.beta = nn.Parameter(torch.zeros(channels))
19
+
20
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
21
+ x = x.transpose(1, -1)
22
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
23
+ return x.transpose(1, -1)
24
+
25
+
26
+ @torch.jit.script # type: ignore
27
+ def fused_add_tanh_sigmoid_multiply(
28
+ input_a: torch.Tensor, input_b: torch.Tensor, n_channels: list[int]
29
+ ) -> torch.Tensor:
30
+ n_channels_int = n_channels[0]
31
+ in_act = input_a + input_b
32
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
33
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
34
+ acts = t_act * s_act
35
+ return acts
36
+
37
+
38
+ class Encoder(nn.Module):
39
+ def __init__(
40
+ self,
41
+ hidden_channels: int,
42
+ filter_channels: int,
43
+ n_heads: int,
44
+ n_layers: int,
45
+ kernel_size: int = 1,
46
+ p_dropout: float = 0.0,
47
+ window_size: int = 4,
48
+ isflow: bool = True,
49
+ **kwargs: Any,
50
+ ) -> None:
51
+ super().__init__()
52
+ self.hidden_channels = hidden_channels
53
+ self.filter_channels = filter_channels
54
+ self.n_heads = n_heads
55
+ self.n_layers = n_layers
56
+ self.kernel_size = kernel_size
57
+ self.p_dropout = p_dropout
58
+ self.window_size = window_size
59
+ # if isflow:
60
+ # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
61
+ # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
62
+ # self.cond_layer = weight_norm(cond_layer, name='weight')
63
+ # self.gin_channels = 256
64
+ self.cond_layer_idx = self.n_layers
65
+ if "gin_channels" in kwargs:
66
+ self.gin_channels = kwargs["gin_channels"]
67
+ if self.gin_channels != 0:
68
+ self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
69
+ # vits2 says 3rd block, so idx is 2 by default
70
+ self.cond_layer_idx = (
71
+ kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
72
+ )
73
+ # logger.debug(self.gin_channels, self.cond_layer_idx)
74
+ assert (
75
+ self.cond_layer_idx < self.n_layers
76
+ ), "cond_layer_idx should be less than n_layers"
77
+ self.drop = nn.Dropout(p_dropout)
78
+ self.attn_layers = nn.ModuleList()
79
+ self.norm_layers_1 = nn.ModuleList()
80
+ self.ffn_layers = nn.ModuleList()
81
+ self.norm_layers_2 = nn.ModuleList()
82
+ for i in range(self.n_layers):
83
+ self.attn_layers.append(
84
+ MultiHeadAttention(
85
+ hidden_channels,
86
+ hidden_channels,
87
+ n_heads,
88
+ p_dropout=p_dropout,
89
+ window_size=window_size,
90
+ )
91
+ )
92
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
93
+ self.ffn_layers.append(
94
+ FFN(
95
+ hidden_channels,
96
+ hidden_channels,
97
+ filter_channels,
98
+ kernel_size,
99
+ p_dropout=p_dropout,
100
+ )
101
+ )
102
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
103
+
104
+ def forward(
105
+ self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
106
+ ) -> torch.Tensor:
107
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
108
+ x = x * x_mask
109
+ for i in range(self.n_layers):
110
+ if i == self.cond_layer_idx and g is not None:
111
+ g = self.spk_emb_linear(g.transpose(1, 2))
112
+ assert g is not None
113
+ g = g.transpose(1, 2)
114
+ x = x + g
115
+ x = x * x_mask
116
+ y = self.attn_layers[i](x, x, attn_mask)
117
+ y = self.drop(y)
118
+ x = self.norm_layers_1[i](x + y)
119
+
120
+ y = self.ffn_layers[i](x, x_mask)
121
+ y = self.drop(y)
122
+ x = self.norm_layers_2[i](x + y)
123
+ x = x * x_mask
124
+ return x
125
+
126
+
127
+ class Decoder(nn.Module):
128
+ def __init__(
129
+ self,
130
+ hidden_channels: int,
131
+ filter_channels: int,
132
+ n_heads: int,
133
+ n_layers: int,
134
+ kernel_size: int = 1,
135
+ p_dropout: float = 0.0,
136
+ proximal_bias: bool = False,
137
+ proximal_init: bool = True,
138
+ **kwargs: Any,
139
+ ) -> None:
140
+ super().__init__()
141
+ self.hidden_channels = hidden_channels
142
+ self.filter_channels = filter_channels
143
+ self.n_heads = n_heads
144
+ self.n_layers = n_layers
145
+ self.kernel_size = kernel_size
146
+ self.p_dropout = p_dropout
147
+ self.proximal_bias = proximal_bias
148
+ self.proximal_init = proximal_init
149
+
150
+ self.drop = nn.Dropout(p_dropout)
151
+ self.self_attn_layers = nn.ModuleList()
152
+ self.norm_layers_0 = nn.ModuleList()
153
+ self.encdec_attn_layers = nn.ModuleList()
154
+ self.norm_layers_1 = nn.ModuleList()
155
+ self.ffn_layers = nn.ModuleList()
156
+ self.norm_layers_2 = nn.ModuleList()
157
+ for i in range(self.n_layers):
158
+ self.self_attn_layers.append(
159
+ MultiHeadAttention(
160
+ hidden_channels,
161
+ hidden_channels,
162
+ n_heads,
163
+ p_dropout=p_dropout,
164
+ proximal_bias=proximal_bias,
165
+ proximal_init=proximal_init,
166
+ )
167
+ )
168
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
169
+ self.encdec_attn_layers.append(
170
+ MultiHeadAttention(
171
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
172
+ )
173
+ )
174
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
175
+ self.ffn_layers.append(
176
+ FFN(
177
+ hidden_channels,
178
+ hidden_channels,
179
+ filter_channels,
180
+ kernel_size,
181
+ p_dropout=p_dropout,
182
+ causal=True,
183
+ )
184
+ )
185
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
186
+
187
+ def forward(
188
+ self,
189
+ x: torch.Tensor,
190
+ x_mask: torch.Tensor,
191
+ h: torch.Tensor,
192
+ h_mask: torch.Tensor,
193
+ ) -> torch.Tensor:
194
+ """
195
+ x: decoder input
196
+ h: encoder output
197
+ """
198
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
199
+ device=x.device, dtype=x.dtype
200
+ )
201
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
202
+ x = x * x_mask
203
+ for i in range(self.n_layers):
204
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
205
+ y = self.drop(y)
206
+ x = self.norm_layers_0[i](x + y)
207
+
208
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
209
+ y = self.drop(y)
210
+ x = self.norm_layers_1[i](x + y)
211
+
212
+ y = self.ffn_layers[i](x, x_mask)
213
+ y = self.drop(y)
214
+ x = self.norm_layers_2[i](x + y)
215
+ x = x * x_mask
216
+ return x
217
+
218
+
219
+ class MultiHeadAttention(nn.Module):
220
+ def __init__(
221
+ self,
222
+ channels: int,
223
+ out_channels: int,
224
+ n_heads: int,
225
+ p_dropout: float = 0.0,
226
+ window_size: Optional[int] = None,
227
+ heads_share: bool = True,
228
+ block_length: Optional[int] = None,
229
+ proximal_bias: bool = False,
230
+ proximal_init: bool = False,
231
+ ) -> None:
232
+ super().__init__()
233
+ assert channels % n_heads == 0
234
+
235
+ self.channels = channels
236
+ self.out_channels = out_channels
237
+ self.n_heads = n_heads
238
+ self.p_dropout = p_dropout
239
+ self.window_size = window_size
240
+ self.heads_share = heads_share
241
+ self.block_length = block_length
242
+ self.proximal_bias = proximal_bias
243
+ self.proximal_init = proximal_init
244
+ self.attn = None
245
+
246
+ self.k_channels = channels // n_heads
247
+ self.conv_q = nn.Conv1d(channels, channels, 1)
248
+ self.conv_k = nn.Conv1d(channels, channels, 1)
249
+ self.conv_v = nn.Conv1d(channels, channels, 1)
250
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
251
+ self.drop = nn.Dropout(p_dropout)
252
+
253
+ if window_size is not None:
254
+ n_heads_rel = 1 if heads_share else n_heads
255
+ rel_stddev = self.k_channels**-0.5
256
+ self.emb_rel_k = nn.Parameter(
257
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
258
+ * rel_stddev
259
+ )
260
+ self.emb_rel_v = nn.Parameter(
261
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
262
+ * rel_stddev
263
+ )
264
+
265
+ nn.init.xavier_uniform_(self.conv_q.weight)
266
+ nn.init.xavier_uniform_(self.conv_k.weight)
267
+ nn.init.xavier_uniform_(self.conv_v.weight)
268
+ if proximal_init:
269
+ with torch.no_grad():
270
+ self.conv_k.weight.copy_(self.conv_q.weight)
271
+ assert self.conv_k.bias is not None
272
+ assert self.conv_q.bias is not None
273
+ self.conv_k.bias.copy_(self.conv_q.bias)
274
+
275
+ def forward(
276
+ self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
277
+ ) -> torch.Tensor:
278
+ q = self.conv_q(x)
279
+ k = self.conv_k(c)
280
+ v = self.conv_v(c)
281
+
282
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
283
+
284
+ x = self.conv_o(x)
285
+ return x
286
+
287
+ def attention(
288
+ self,
289
+ query: torch.Tensor,
290
+ key: torch.Tensor,
291
+ value: torch.Tensor,
292
+ mask: Optional[torch.Tensor] = None,
293
+ ) -> tuple[torch.Tensor, torch.Tensor]:
294
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
295
+ b, d, t_s, t_t = (*key.size(), query.size(2))
296
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
297
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
298
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
299
+
300
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
301
+ if self.window_size is not None:
302
+ assert (
303
+ t_s == t_t
304
+ ), "Relative attention is only available for self-attention."
305
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
306
+ rel_logits = self._matmul_with_relative_keys(
307
+ query / math.sqrt(self.k_channels), key_relative_embeddings
308
+ )
309
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
310
+ scores = scores + scores_local
311
+ if self.proximal_bias:
312
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
313
+ scores = scores + self._attention_bias_proximal(t_s).to(
314
+ device=scores.device, dtype=scores.dtype
315
+ )
316
+ if mask is not None:
317
+ scores = scores.masked_fill(mask == 0, -1e4)
318
+ if self.block_length is not None:
319
+ assert (
320
+ t_s == t_t
321
+ ), "Local attention is only available for self-attention."
322
+ block_mask = (
323
+ torch.ones_like(scores)
324
+ .triu(-self.block_length)
325
+ .tril(self.block_length)
326
+ )
327
+ scores = scores.masked_fill(block_mask == 0, -1e4)
328
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
329
+ p_attn = self.drop(p_attn)
330
+ output = torch.matmul(p_attn, value)
331
+ if self.window_size is not None:
332
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
333
+ value_relative_embeddings = self._get_relative_embeddings(
334
+ self.emb_rel_v, t_s
335
+ )
336
+ output = output + self._matmul_with_relative_values(
337
+ relative_weights, value_relative_embeddings
338
+ )
339
+ output = (
340
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
341
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
342
+ return output, p_attn
343
+
344
+ def _matmul_with_relative_values(
345
+ self, x: torch.Tensor, y: torch.Tensor
346
+ ) -> torch.Tensor:
347
+ """
348
+ x: [b, h, l, m]
349
+ y: [h or 1, m, d]
350
+ ret: [b, h, l, d]
351
+ """
352
+ ret = torch.matmul(x, y.unsqueeze(0))
353
+ return ret
354
+
355
+ def _matmul_with_relative_keys(
356
+ self, x: torch.Tensor, y: torch.Tensor
357
+ ) -> torch.Tensor:
358
+ """
359
+ x: [b, h, l, d]
360
+ y: [h or 1, m, d]
361
+ ret: [b, h, l, m]
362
+ """
363
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
364
+ return ret
365
+
366
+ def _get_relative_embeddings(
367
+ self, relative_embeddings: torch.Tensor, length: int
368
+ ) -> torch.Tensor:
369
+ assert self.window_size is not None
370
+ 2 * self.window_size + 1 # type: ignore
371
+ # Pad first before slice to avoid using cond ops.
372
+ pad_length = max(length - (self.window_size + 1), 0)
373
+ slice_start_position = max((self.window_size + 1) - length, 0)
374
+ slice_end_position = slice_start_position + 2 * length - 1
375
+ if pad_length > 0:
376
+ padded_relative_embeddings = F.pad(
377
+ relative_embeddings,
378
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
379
+ )
380
+ else:
381
+ padded_relative_embeddings = relative_embeddings
382
+ used_relative_embeddings = padded_relative_embeddings[
383
+ :, slice_start_position:slice_end_position
384
+ ]
385
+ return used_relative_embeddings
386
+
387
+ def _relative_position_to_absolute_position(self, x: torch.Tensor) -> torch.Tensor:
388
+ """
389
+ x: [b, h, l, 2*l-1]
390
+ ret: [b, h, l, l]
391
+ """
392
+ batch, heads, length, _ = x.size()
393
+ # Concat columns of pad to shift from relative to absolute indexing.
394
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
395
+
396
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
397
+ x_flat = x.view([batch, heads, length * 2 * length])
398
+ x_flat = F.pad(
399
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
400
+ )
401
+
402
+ # Reshape and slice out the padded elements.
403
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
404
+ :, :, :length, length - 1 :
405
+ ]
406
+ return x_final
407
+
408
+ def _absolute_position_to_relative_position(self, x: torch.Tensor) -> torch.Tensor:
409
+ """
410
+ x: [b, h, l, l]
411
+ ret: [b, h, l, 2*l-1]
412
+ """
413
+ batch, heads, length, _ = x.size()
414
+ # pad along column
415
+ x = F.pad(
416
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
417
+ )
418
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
419
+ # add 0's in the beginning that will skew the elements after reshape
420
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
421
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
422
+ return x_final
423
+
424
+ def _attention_bias_proximal(self, length: int) -> torch.Tensor:
425
+ """Bias for self-attention to encourage attention to close positions.
426
+ Args:
427
+ length: an integer scalar.
428
+ Returns:
429
+ a Tensor with shape [1, 1, length, length]
430
+ """
431
+ r = torch.arange(length, dtype=torch.float32)
432
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
433
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
434
+
435
+
436
+ class FFN(nn.Module):
437
+ def __init__(
438
+ self,
439
+ in_channels: int,
440
+ out_channels: int,
441
+ filter_channels: int,
442
+ kernel_size: int,
443
+ p_dropout: float = 0.0,
444
+ activation: Optional[str] = None,
445
+ causal: bool = False,
446
+ ) -> None:
447
+ super().__init__()
448
+ self.in_channels = in_channels
449
+ self.out_channels = out_channels
450
+ self.filter_channels = filter_channels
451
+ self.kernel_size = kernel_size
452
+ self.p_dropout = p_dropout
453
+ self.activation = activation
454
+ self.causal = causal
455
+
456
+ if causal:
457
+ self.padding = self._causal_padding
458
+ else:
459
+ self.padding = self._same_padding
460
+
461
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
462
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
463
+ self.drop = nn.Dropout(p_dropout)
464
+
465
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
466
+ x = self.conv_1(self.padding(x * x_mask))
467
+ if self.activation == "gelu":
468
+ x = x * torch.sigmoid(1.702 * x)
469
+ else:
470
+ x = torch.relu(x)
471
+ x = self.drop(x)
472
+ x = self.conv_2(self.padding(x * x_mask))
473
+ return x * x_mask
474
+
475
+ def _causal_padding(self, x: torch.Tensor) -> torch.Tensor:
476
+ if self.kernel_size == 1:
477
+ return x
478
+ pad_l = self.kernel_size - 1
479
+ pad_r = 0
480
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
481
+ x = F.pad(x, commons.convert_pad_shape(padding))
482
+ return x
483
+
484
+ def _same_padding(self, x: torch.Tensor) -> torch.Tensor:
485
+ if self.kernel_size == 1:
486
+ return x
487
+ pad_l = (self.kernel_size - 1) // 2
488
+ pad_r = self.kernel_size // 2
489
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
490
+ x = F.pad(x, commons.convert_pad_shape(padding))
491
+ return x
style_bert_vits2/models/commons.py ADDED
@@ -0,0 +1,223 @@
1
+ """
2
+ 以下に記述されている関数のコメントはリファクタリング時に GPT-4 に生成させたもので、
3
+ コードと完全に一致している保証はない。あくまで参考程度とすること。
4
+ """
5
+
6
+ from typing import Any, Optional, Union
7
+
8
+ import torch
9
+ from torch.nn import functional as F
10
+
11
+
12
+ def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None:
13
+ """
14
+ モジュールの重みを初期化する
15
+
16
+ Args:
17
+ m (torch.nn.Module): 重みを初期化する対象のモジュール
18
+ mean (float): 正規分布の平均
19
+ std (float): 正規分布の標準偏差
20
+ """
21
+ classname = m.__class__.__name__
22
+ if classname.find("Conv") != -1:
23
+ m.weight.data.normal_(mean, std)
24
+
25
+
26
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
27
+ """
28
+ カーネルサイズと膨張率からパディングの大きさを計算する
29
+
30
+ Args:
31
+ kernel_size (int): カーネルのサイズ
32
+ dilation (int): 膨張率
33
+
34
+ Returns:
35
+ int: 計算されたパディングの大きさ
36
+ """
37
+ return int((kernel_size * dilation - dilation) / 2)
38
+
39
+
40
+ def convert_pad_shape(pad_shape: list[list[Any]]) -> list[Any]:
41
+ """
42
+ パディングの形状を変換する
43
+
44
+ Args:
45
+ pad_shape (list[list[Any]]): 変換前のパディングの形状
46
+
47
+ Returns:
48
+ list[Any]: 変換後のパディングの形状
49
+ """
50
+ layer = pad_shape[::-1]
51
+ new_pad_shape = [item for sublist in layer for item in sublist]
52
+ return new_pad_shape
53
+
54
+
55
+ def intersperse(lst: list[Any], item: Any) -> list[Any]:
56
+ """
57
+ Insert a given item between (and around) the elements of a list
58
+
59
+ Args:
60
+ lst (list[Any]): Original list
61
+ item (Any): Item to insert
62
+
63
+ Returns:
64
+ list[Any]: New list
65
+ """
66
+ result = [item] * (len(lst) * 2 + 1)
67
+ result[1::2] = lst
68
+ return result
69
+
70
+
71
+ def slice_segments(
72
+ x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4
73
+ ) -> torch.Tensor:
74
+ """
75
+ Slice segments out of a tensor
76
+
77
+ Args:
78
+ x (torch.Tensor): Input tensor
79
+ ids_str (torch.Tensor): Start indices of the slices
80
+ segment_size (int, optional): Size of each slice (default: 4)
81
+
82
+ Returns:
83
+ torch.Tensor: Sliced segments
84
+ """
85
+ gather_indices = ids_str.view(x.size(0), 1, 1).repeat(
86
+ 1, x.size(1), 1
87
+ ) + torch.arange(segment_size, device=x.device)
88
+ return torch.gather(x, 2, gather_indices)
89
+
90
+
91
+ def rand_slice_segments(
92
+ x: torch.Tensor, x_lengths: Optional[torch.Tensor] = None, segment_size: int = 4
93
+ ) -> tuple[torch.Tensor, torch.Tensor]:
94
+ """
95
+ Slice random segments out of a tensor
96
+
97
+ Args:
98
+ x (torch.Tensor): Input tensor
99
+ x_lengths (Optional[torch.Tensor], optional): Length of each batch element (default: None)
100
+ segment_size (int, optional): Size of each slice (default: 4)
101
+
102
+ Returns:
103
+ tuple[torch.Tensor, torch.Tensor]: Sliced segments and their start indices
104
+ """
105
+ b, d, t = x.size()
106
+ if x_lengths is None:
107
+ x_lengths = t # type: ignore
108
+ ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0) # type: ignore
109
+ ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
110
+ ret = slice_segments(x, ids_str, segment_size)
111
+ return ret, ids_str
112
+
113
+
114
+ def subsequent_mask(length: int) -> torch.Tensor:
115
+ """
116
+ Generate a lower-triangular mask over subsequent positions
117
+
118
+ Args:
119
+ length (int): Size of the mask
120
+
121
+ Returns:
122
+ torch.Tensor: Generated mask
123
+ """
124
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
125
+ return mask
126
+
127
+
128
+ @torch.jit.script # type: ignore
129
+ def fused_add_tanh_sigmoid_multiply(
130
+ input_a: torch.Tensor, input_b: torch.Tensor, n_channels: torch.Tensor
131
+ ) -> torch.Tensor:
132
+ """
133
+ Perform a fused operation combining addition with tanh and sigmoid activations
134
+
135
+ Args:
136
+ input_a (torch.Tensor): Input tensor A
137
+ input_b (torch.Tensor): Input tensor B
138
+ n_channels (torch.Tensor): Number of channels
139
+
140
+ Returns:
141
+ torch.Tensor: Result of the operation
142
+ """
143
+ n_channels_int = n_channels[0]
144
+ in_act = input_a + input_b
145
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
146
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
147
+ acts = t_act * s_act
148
+ return acts
149
+
150
+
151
+ def sequence_mask(
152
+ length: torch.Tensor, max_length: Optional[int] = None
153
+ ) -> torch.Tensor:
154
+ """
155
+ Generate a sequence mask
156
+
157
+ Args:
158
+ length (torch.Tensor): Length of each sequence
159
+ max_length (Optional[int]): Maximum sequence length; if not given, the maximum of length is used
160
+
161
+ Returns:
162
+ torch.Tensor: Generated sequence mask
163
+ """
164
+ if max_length is None:
165
+ max_length = length.max() # type: ignore
166
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device) # type: ignore
167
+ return x.unsqueeze(0) < length.unsqueeze(1)
168
+
169
+
170
+ def generate_path(duration: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
171
+ """
172
+ Generate an alignment path from durations
173
+
174
+ Args:
175
+ duration (torch.Tensor): Duration of each time step
176
+ mask (torch.Tensor): Mask tensor
177
+
178
+ Returns:
179
+ torch.Tensor: Generated path
180
+ """
181
+ b, _, t_y, t_x = mask.shape
182
+ cum_duration = torch.cumsum(duration, -1)
183
+
184
+ cum_duration_flat = cum_duration.view(b * t_x)
185
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
186
+ path = path.view(b, t_x, t_y)
187
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
188
+ path = path.unsqueeze(1).transpose(2, 3) * mask
189
+ return path
190
+
191
+
192
+ def clip_grad_value_(
193
+ parameters: Union[torch.Tensor, list[torch.Tensor]],
194
+ clip_value: Optional[float],
195
+ norm_type: float = 2.0,
196
+ ) -> float:
197
+ """
198
+ Clip the gradient values
199
+
200
+ Args:
201
+ parameters (Union[torch.Tensor, list[torch.Tensor]]): Parameters whose gradients will be clipped
202
+ clip_value (Optional[float]): Value to clip at; if None, no clipping is performed
203
+ norm_type (float): Type of the norm
204
+
205
+ Returns:
206
+ float: Total norm
207
+ """
208
+ if isinstance(parameters, torch.Tensor):
209
+ parameters = [parameters]
210
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
211
+ norm_type = float(norm_type)
212
+ if clip_value is not None:
213
+ clip_value = float(clip_value)
214
+
215
+ total_norm = 0.0
216
+ for p in parameters:
217
+ assert p.grad is not None
218
+ param_norm = p.grad.data.norm(norm_type)
219
+ total_norm += param_norm.item() ** norm_type
220
+ if clip_value is not None:
221
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
222
+ total_norm = total_norm ** (1.0 / norm_type)
223
+ return total_norm
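A minimal usage sketch (not part of this commit) of two of the helpers above; it only assumes the style_bert_vits2 package is importable:

import torch
from style_bert_vits2.models import commons

phones = [1, 5, 9]
print(commons.intersperse(phones, 0))     # -> [0, 1, 0, 5, 0, 9, 0]
lengths = torch.tensor([3, 5])
print(commons.sequence_mask(lengths, 5))  # -> boolean mask of shape [2, 5]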
style_bert_vits2/models/hyper_parameters.py ADDED
@@ -0,0 +1,129 @@
1
+ """
2
+ Pydantic models representing the hyperparameters of a Style-Bert-VITS2 model.
3
+ The default values are roughly identical to those defined in configs/config_jp_extra.json
4
+ and serve as a fail-safe in case a loaded config.json is missing a key.
5
+ """
6
+
7
+ from pathlib import Path
8
+ from typing import Optional, Union
9
+
10
+ from pydantic import BaseModel, ConfigDict
11
+
12
+
13
+ class HyperParametersTrain(BaseModel):
14
+ log_interval: int = 200
15
+ eval_interval: int = 1000
16
+ seed: int = 42
17
+ epochs: int = 1000
18
+ learning_rate: float = 0.0001
19
+ betas: tuple[float, float] = (0.8, 0.99)
20
+ eps: float = 1e-9
21
+ batch_size: int = 2
22
+ bf16_run: bool = False
23
+ fp16_run: bool = False
24
+ lr_decay: float = 0.99996
25
+ segment_size: int = 16384
26
+ init_lr_ratio: int = 1
27
+ warmup_epochs: int = 0
28
+ c_mel: int = 45
29
+ c_kl: float = 1.0
30
+ c_commit: int = 100
31
+ skip_optimizer: bool = False
32
+ freeze_ZH_bert: bool = False
33
+ freeze_JP_bert: bool = False
34
+ freeze_EN_bert: bool = False
35
+ freeze_emo: bool = False
36
+ freeze_style: bool = False
37
+ freeze_decoder: bool = False
38
+
39
+
40
+ class HyperParametersData(BaseModel):
41
+ use_jp_extra: bool = True
42
+ training_files: str = "Data/Dummy/train.list"
43
+ validation_files: str = "Data/Dummy/val.list"
44
+ max_wav_value: float = 32768.0
45
+ sampling_rate: int = 44100
46
+ filter_length: int = 2048
47
+ hop_length: int = 512
48
+ win_length: int = 2048
49
+ n_mel_channels: int = 128
50
+ mel_fmin: float = 0.0
51
+ mel_fmax: Optional[float] = None
52
+ add_blank: bool = True
53
+ n_speakers: int = 1
54
+ cleaned_text: bool = True
55
+ spk2id: dict[str, int] = {
56
+ "Dummy": 0,
57
+ }
58
+ num_styles: int = 1
59
+ style2id: dict[str, int] = {
60
+ "Neutral": 0,
61
+ }
62
+
63
+
64
+ class HyperParametersModelSLM(BaseModel):
65
+ model: str = "./slm/wavlm-base-plus"
66
+ sr: int = 16000
67
+ hidden: int = 768
68
+ nlayers: int = 13
69
+ initial_channel: int = 64
70
+
71
+
72
+ class HyperParametersModel(BaseModel):
73
+ use_spk_conditioned_encoder: bool = True
74
+ use_noise_scaled_mas: bool = True
75
+ use_mel_posterior_encoder: bool = False
76
+ use_duration_discriminator: bool = False
77
+ use_wavlm_discriminator: bool = True
78
+ inter_channels: int = 192
79
+ hidden_channels: int = 192
80
+ filter_channels: int = 768
81
+ n_heads: int = 2
82
+ n_layers: int = 6
83
+ kernel_size: int = 3
84
+ p_dropout: float = 0.1
85
+ resblock: str = "1"
86
+ resblock_kernel_sizes: list[int] = [3, 7, 11]
87
+ resblock_dilation_sizes: list[list[int]] = [
88
+ [1, 3, 5],
89
+ [1, 3, 5],
90
+ [1, 3, 5],
91
+ ]
92
+ upsample_rates: list[int] = [8, 8, 2, 2, 2]
93
+ upsample_initial_channel: int = 512
94
+ upsample_kernel_sizes: list[int] = [16, 16, 8, 2, 2]
95
+ n_layers_q: int = 3
96
+ use_spectral_norm: bool = False
97
+ gin_channels: int = 512
98
+ slm: HyperParametersModelSLM = HyperParametersModelSLM()
99
+
100
+
101
+ class HyperParameters(BaseModel):
102
+ model_name: str = "Dummy"
103
+ version: str = "2.0-JP-Extra"
104
+ train: HyperParametersTrain = HyperParametersTrain()
105
+ data: HyperParametersData = HyperParametersData()
106
+ model: HyperParametersModel = HyperParametersModel()
107
+
108
+ # The following parameters are set dynamically only during training (normally absent from config.json)
109
+ model_dir: Optional[str] = None
110
+ speedup: bool = False
111
+ repo_id: Optional[str] = None
112
+
113
+ # Exclude names starting with model_ from Pydantic's protected namespaces
114
+ model_config = ConfigDict(protected_namespaces=())
115
+
116
+ @staticmethod
117
+ def load_from_json(json_path: Union[str, Path]) -> "HyperParameters":
118
+ """
119
+ Load hyperparameters from the given JSON file.
120
+
121
+ Args:
122
+ json_path (Union[str, Path]): Path to the JSON file
123
+
124
+ Returns:
125
+ HyperParameters: Loaded hyperparameters
126
+ """
127
+
128
+ with open(json_path, encoding="utf-8") as f:
129
+ return HyperParameters.model_validate_json(f.read())
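A minimal usage sketch (not part of this commit); the config path is hypothetical:

from style_bert_vits2.models.hyper_parameters import HyperParameters

hps = HyperParameters.load_from_json("path/to/config.json")
print(hps.version, hps.data.sampling_rate, hps.model.gin_channels)
# Keys missing from the JSON fall back to the defaults defined above.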
style_bert_vits2/models/infer.py ADDED
@@ -0,0 +1,308 @@
1
+ from typing import Any, Optional, Union, cast
2
+
3
+ import torch
4
+ from numpy.typing import NDArray
5
+
6
+ from style_bert_vits2.constants import Languages
7
+ from style_bert_vits2.logging import logger
8
+ from style_bert_vits2.models import commons, utils
9
+ from style_bert_vits2.models.hyper_parameters import HyperParameters
10
+ from style_bert_vits2.models.models import SynthesizerTrn
11
+ from style_bert_vits2.models.models_jp_extra import (
12
+ SynthesizerTrn as SynthesizerTrnJPExtra,
13
+ )
14
+ from style_bert_vits2.nlp import (
15
+ clean_text,
16
+ cleaned_text_to_sequence,
17
+ extract_bert_feature,
18
+ )
19
+ from style_bert_vits2.nlp.symbols import SYMBOLS
20
+
21
+
22
+ def get_net_g(model_path: str, version: str, device: str, hps: HyperParameters):
23
+ if version.endswith("JP-Extra"):
24
+ logger.info("Using JP-Extra model")
25
+ net_g = SynthesizerTrnJPExtra(
26
+ n_vocab=len(SYMBOLS),
27
+ spec_channels=hps.data.filter_length // 2 + 1,
28
+ segment_size=hps.train.segment_size // hps.data.hop_length,
29
+ n_speakers=hps.data.n_speakers,
30
+ # Pass every value under hps.model as arguments
31
+ use_spk_conditioned_encoder=hps.model.use_spk_conditioned_encoder,
32
+ use_noise_scaled_mas=hps.model.use_noise_scaled_mas,
33
+ use_mel_posterior_encoder=hps.model.use_mel_posterior_encoder,
34
+ use_duration_discriminator=hps.model.use_duration_discriminator,
35
+ use_wavlm_discriminator=hps.model.use_wavlm_discriminator,
36
+ inter_channels=hps.model.inter_channels,
37
+ hidden_channels=hps.model.hidden_channels,
38
+ filter_channels=hps.model.filter_channels,
39
+ n_heads=hps.model.n_heads,
40
+ n_layers=hps.model.n_layers,
41
+ kernel_size=hps.model.kernel_size,
42
+ p_dropout=hps.model.p_dropout,
43
+ resblock=hps.model.resblock,
44
+ resblock_kernel_sizes=hps.model.resblock_kernel_sizes,
45
+ resblock_dilation_sizes=hps.model.resblock_dilation_sizes,
46
+ upsample_rates=hps.model.upsample_rates,
47
+ upsample_initial_channel=hps.model.upsample_initial_channel,
48
+ upsample_kernel_sizes=hps.model.upsample_kernel_sizes,
49
+ n_layers_q=hps.model.n_layers_q,
50
+ use_spectral_norm=hps.model.use_spectral_norm,
51
+ gin_channels=hps.model.gin_channels,
52
+ slm=hps.model.slm,
53
+ ).to(device)
54
+ else:
55
+ logger.info("Using normal model")
56
+ net_g = SynthesizerTrn(
57
+ n_vocab=len(SYMBOLS),
58
+ spec_channels=hps.data.filter_length // 2 + 1,
59
+ segment_size=hps.train.segment_size // hps.data.hop_length,
60
+ n_speakers=hps.data.n_speakers,
61
+ # Pass every value under hps.model as arguments
62
+ use_spk_conditioned_encoder=hps.model.use_spk_conditioned_encoder,
63
+ use_noise_scaled_mas=hps.model.use_noise_scaled_mas,
64
+ use_mel_posterior_encoder=hps.model.use_mel_posterior_encoder,
65
+ use_duration_discriminator=hps.model.use_duration_discriminator,
66
+ use_wavlm_discriminator=hps.model.use_wavlm_discriminator,
67
+ inter_channels=hps.model.inter_channels,
68
+ hidden_channels=hps.model.hidden_channels,
69
+ filter_channels=hps.model.filter_channels,
70
+ n_heads=hps.model.n_heads,
71
+ n_layers=hps.model.n_layers,
72
+ kernel_size=hps.model.kernel_size,
73
+ p_dropout=hps.model.p_dropout,
74
+ resblock=hps.model.resblock,
75
+ resblock_kernel_sizes=hps.model.resblock_kernel_sizes,
76
+ resblock_dilation_sizes=hps.model.resblock_dilation_sizes,
77
+ upsample_rates=hps.model.upsample_rates,
78
+ upsample_initial_channel=hps.model.upsample_initial_channel,
79
+ upsample_kernel_sizes=hps.model.upsample_kernel_sizes,
80
+ n_layers_q=hps.model.n_layers_q,
81
+ use_spectral_norm=hps.model.use_spectral_norm,
82
+ gin_channels=hps.model.gin_channels,
83
+ slm=hps.model.slm,
84
+ ).to(device)
85
+ net_g.state_dict()
86
+ _ = net_g.eval()
87
+ if model_path.endswith(".pth") or model_path.endswith(".pt"):
88
+ _ = utils.checkpoints.load_checkpoint(
89
+ model_path, net_g, None, skip_optimizer=True
90
+ )
91
+ elif model_path.endswith(".safetensors"):
92
+ _ = utils.safetensors.load_safetensors(model_path, net_g, True)
93
+ else:
94
+ raise ValueError(f"Unknown model format: {model_path}")
95
+ return net_g
96
+
97
+
98
+ def get_text(
99
+ text: str,
100
+ language_str: Languages,
101
+ hps: HyperParameters,
102
+ device: str,
103
+ assist_text: Optional[str] = None,
104
+ assist_text_weight: float = 0.7,
105
+ given_phone: Optional[list[str]] = None,
106
+ given_tone: Optional[list[int]] = None,
107
+ ):
108
+ use_jp_extra = hps.version.endswith("JP-Extra")
109
+ # Only called at inference time, so raise_yomi_error is set to False
110
+ norm_text, phone, tone, word2ph = clean_text(
111
+ text,
112
+ language_str,
113
+ use_jp_extra=use_jp_extra,
114
+ raise_yomi_error=False,
115
+ )
116
+ # If both phone and tone are given, use them
117
+ if given_phone is not None and given_tone is not None:
118
+ # The given phone and the given tone must have the same length
119
+ if len(given_phone) != len(given_tone):
120
+ raise InvalidPhoneError(
121
+ f"Length of given_phone ({len(given_phone)}) != length of given_tone ({len(given_tone)})"
122
+ )
123
+ # The number of given phonemes does not match the number of phonemes in the reading generated by pyopenjtalk
124
+ if len(given_phone) != sum(word2ph):
125
+ # For Japanese, adjust word2ph so that len(given_phone) matches sum(word2ph)
126
+ # For other languages there is no obvious way to adjust word2ph, so raise an error
127
+ if language_str == Languages.JP:
128
+ from style_bert_vits2.nlp.japanese.g2p import adjust_word2ph
129
+
130
+ word2ph = adjust_word2ph(word2ph, phone, given_phone)
131
+ # After the adjustment above, the sum of word2ph should match the length of given_phone
132
+ # If it still does not match, the read text and given_phone usually diverge too much to be reconciled
133
+ if len(given_phone) != sum(word2ph):
134
+ raise InvalidPhoneError(
135
+ f"Length of given_phone ({len(given_phone)}) != sum of word2ph ({sum(word2ph)})"
136
+ )
137
+ else:
138
+ raise InvalidPhoneError(
139
+ f"Length of given_phone ({len(given_phone)}) != sum of word2ph ({sum(word2ph)})"
140
+ )
141
+ phone = given_phone
142
+ # The generated or given phone and the given tone must have the same length
143
+ if len(phone) != len(given_tone):
144
+ raise InvalidToneError(
145
+ f"Length of phone ({len(phone)}) != length of given_tone ({len(given_tone)})"
146
+ )
147
+ tone = given_tone
148
+ # If only tone is given, use it together with the phone generated by clean_text()
149
+ elif given_tone is not None:
150
+ # The generated phone and the given tone must have the same length
151
+ if len(phone) != len(given_tone):
152
+ raise InvalidToneError(
153
+ f"Length of phone ({len(phone)}) != length of given_tone ({len(given_tone)})"
154
+ )
155
+ tone = given_tone
156
+ phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
157
+
158
+ if hps.data.add_blank:
159
+ phone = commons.intersperse(phone, 0)
160
+ tone = commons.intersperse(tone, 0)
161
+ language = commons.intersperse(language, 0)
162
+ for i in range(len(word2ph)):
163
+ word2ph[i] = word2ph[i] * 2
164
+ word2ph[0] += 1
165
+ bert_ori = extract_bert_feature(
166
+ norm_text,
167
+ word2ph,
168
+ language_str,
169
+ device,
170
+ assist_text,
171
+ assist_text_weight,
172
+ )
173
+ del word2ph
174
+ assert bert_ori.shape[-1] == len(phone), phone
175
+
176
+ if language_str == Languages.ZH:
177
+ bert = bert_ori
178
+ ja_bert = torch.zeros(1024, len(phone))
179
+ en_bert = torch.zeros(1024, len(phone))
180
+ elif language_str == Languages.JP:
181
+ bert = torch.zeros(1024, len(phone))
182
+ ja_bert = bert_ori
183
+ en_bert = torch.zeros(1024, len(phone))
184
+ elif language_str == Languages.EN:
185
+ bert = torch.zeros(1024, len(phone))
186
+ ja_bert = torch.zeros(1024, len(phone))
187
+ en_bert = bert_ori
188
+ else:
189
+ raise ValueError("language_str should be ZH, JP or EN")
190
+
191
+ assert bert.shape[-1] == len(
192
+ phone
193
+ ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
194
+
195
+ phone = torch.LongTensor(phone)
196
+ tone = torch.LongTensor(tone)
197
+ language = torch.LongTensor(language)
198
+ return bert, ja_bert, en_bert, phone, tone, language
199
+
200
+
201
+ def infer(
202
+ text: str,
203
+ style_vec: NDArray[Any],
204
+ sdp_ratio: float,
205
+ noise_scale: float,
206
+ noise_scale_w: float,
207
+ length_scale: float,
208
+ sid: int,  # In the original Bert-VITS2 this is speaker_name: str, but here it is the speaker id
209
+ language: Languages,
210
+ hps: HyperParameters,
211
+ net_g: Union[SynthesizerTrn, SynthesizerTrnJPExtra],
212
+ device: str,
213
+ skip_start: bool = False,
214
+ skip_end: bool = False,
215
+ assist_text: Optional[str] = None,
216
+ assist_text_weight: float = 0.7,
217
+ given_phone: Optional[list[str]] = None,
218
+ given_tone: Optional[list[int]] = None,
219
+ ):
220
+ is_jp_extra = hps.version.endswith("JP-Extra")
221
+ bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
222
+ text,
223
+ language,
224
+ hps,
225
+ device,
226
+ assist_text=assist_text,
227
+ assist_text_weight=assist_text_weight,
228
+ given_phone=given_phone,
229
+ given_tone=given_tone,
230
+ )
231
+ if skip_start:
232
+ phones = phones[3:]
233
+ tones = tones[3:]
234
+ lang_ids = lang_ids[3:]
235
+ bert = bert[:, 3:]
236
+ ja_bert = ja_bert[:, 3:]
237
+ en_bert = en_bert[:, 3:]
238
+ if skip_end:
239
+ phones = phones[:-2]
240
+ tones = tones[:-2]
241
+ lang_ids = lang_ids[:-2]
242
+ bert = bert[:, :-2]
243
+ ja_bert = ja_bert[:, :-2]
244
+ en_bert = en_bert[:, :-2]
245
+ with torch.no_grad():
246
+ x_tst = phones.to(device).unsqueeze(0)
247
+ tones = tones.to(device).unsqueeze(0)
248
+ lang_ids = lang_ids.to(device).unsqueeze(0)
249
+ bert = bert.to(device).unsqueeze(0)
250
+ ja_bert = ja_bert.to(device).unsqueeze(0)
251
+ en_bert = en_bert.to(device).unsqueeze(0)
252
+ x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
253
+ style_vec_tensor = torch.from_numpy(style_vec).to(device).unsqueeze(0)
254
+ del phones
255
+ sid_tensor = torch.LongTensor([sid]).to(device)
256
+ if is_jp_extra:
257
+ output = cast(SynthesizerTrnJPExtra, net_g).infer(
258
+ x_tst,
259
+ x_tst_lengths,
260
+ sid_tensor,
261
+ tones,
262
+ lang_ids,
263
+ ja_bert,
264
+ style_vec=style_vec_tensor,
265
+ sdp_ratio=sdp_ratio,
266
+ noise_scale=noise_scale,
267
+ noise_scale_w=noise_scale_w,
268
+ length_scale=length_scale,
269
+ )
270
+ else:
271
+ output = cast(SynthesizerTrn, net_g).infer(
272
+ x_tst,
273
+ x_tst_lengths,
274
+ sid_tensor,
275
+ tones,
276
+ lang_ids,
277
+ bert,
278
+ ja_bert,
279
+ en_bert,
280
+ style_vec=style_vec_tensor,
281
+ sdp_ratio=sdp_ratio,
282
+ noise_scale=noise_scale,
283
+ noise_scale_w=noise_scale_w,
284
+ length_scale=length_scale,
285
+ )
286
+ audio = output[0][0, 0].data.cpu().float().numpy()
287
+ del (
288
+ x_tst,
289
+ tones,
290
+ lang_ids,
291
+ bert,
292
+ x_tst_lengths,
293
+ sid_tensor,
294
+ ja_bert,
295
+ en_bert,
296
+ style_vec,
297
+ ) # , emo
298
+ if torch.cuda.is_available():
299
+ torch.cuda.empty_cache()
300
+ return audio
301
+
302
+
303
+ class InvalidPhoneError(ValueError):
304
+ pass
305
+
306
+
307
+ class InvalidToneError(ValueError):
308
+ pass
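A hedged end-to-end sketch (not part of this commit) of how get_net_g() and infer() fit together. The config/model/style-vector paths, the layout of style_vectors.npy, and the speaker id are assumptions; the argument names follow the signatures defined above:

import numpy as np
from style_bert_vits2.constants import Languages
from style_bert_vits2.models.hyper_parameters import HyperParameters
from style_bert_vits2.models.infer import get_net_g, infer

hps = HyperParameters.load_from_json("path/to/config.json")
net_g = get_net_g("path/to/model.safetensors", hps.version, "cpu", hps)
style_vec = np.load("path/to/style_vectors.npy")[0]  # assumed: one 256-dim style vector per row
audio = infer(
    "こんにちは",  # text to synthesize
    style_vec,
    sdp_ratio=0.2,
    noise_scale=0.6,
    noise_scale_w=0.8,
    length_scale=1.0,
    sid=0,  # speaker id (hypothetical)
    language=Languages.JP,
    hps=hps,
    net_g=net_g,
    device="cpu",
)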
style_bert_vits2/models/models.py ADDED
@@ -0,0 +1,1102 @@
1
+ import math
2
+ from typing import Any, Optional
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import Conv1d, Conv2d, ConvTranspose1d
7
+ from torch.nn import functional as F
8
+ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
9
+
10
+ from style_bert_vits2.models import attentions, commons, modules, monotonic_alignment
11
+ from style_bert_vits2.nlp.symbols import NUM_LANGUAGES, NUM_TONES, SYMBOLS
12
+
13
+
14
+ class DurationDiscriminator(nn.Module): # vits2
15
+ def __init__(
16
+ self,
17
+ in_channels: int,
18
+ filter_channels: int,
19
+ kernel_size: int,
20
+ p_dropout: float,
21
+ gin_channels: int = 0,
22
+ ) -> None:
23
+ super().__init__()
24
+
25
+ self.in_channels = in_channels
26
+ self.filter_channels = filter_channels
27
+ self.kernel_size = kernel_size
28
+ self.p_dropout = p_dropout
29
+ self.gin_channels = gin_channels
30
+
31
+ self.drop = nn.Dropout(p_dropout)
32
+ self.conv_1 = nn.Conv1d(
33
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
34
+ )
35
+ self.norm_1 = modules.LayerNorm(filter_channels)
36
+ self.conv_2 = nn.Conv1d(
37
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
38
+ )
39
+ self.norm_2 = modules.LayerNorm(filter_channels)
40
+ self.dur_proj = nn.Conv1d(1, filter_channels, 1)
41
+
42
+ self.pre_out_conv_1 = nn.Conv1d(
43
+ 2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
44
+ )
45
+ self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
46
+ self.pre_out_conv_2 = nn.Conv1d(
47
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
48
+ )
49
+ self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
50
+
51
+ if gin_channels != 0:
52
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
53
+
54
+ self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
55
+
56
+ def forward_probability(
57
+ self,
58
+ x: torch.Tensor,
59
+ x_mask: torch.Tensor,
60
+ dur: torch.Tensor,
61
+ g: Optional[torch.Tensor] = None,
62
+ ) -> torch.Tensor:
63
+ dur = self.dur_proj(dur)
64
+ x = torch.cat([x, dur], dim=1)
65
+ x = self.pre_out_conv_1(x * x_mask)
66
+ x = torch.relu(x)
67
+ x = self.pre_out_norm_1(x)
68
+ x = self.drop(x)
69
+ x = self.pre_out_conv_2(x * x_mask)
70
+ x = torch.relu(x)
71
+ x = self.pre_out_norm_2(x)
72
+ x = self.drop(x)
73
+ x = x * x_mask
74
+ x = x.transpose(1, 2)
75
+ output_prob = self.output_layer(x)
76
+ return output_prob
77
+
78
+ def forward(
79
+ self,
80
+ x: torch.Tensor,
81
+ x_mask: torch.Tensor,
82
+ dur_r: torch.Tensor,
83
+ dur_hat: torch.Tensor,
84
+ g: Optional[torch.Tensor] = None,
85
+ ) -> list[torch.Tensor]:
86
+ x = torch.detach(x)
87
+ if g is not None:
88
+ g = torch.detach(g)
89
+ x = x + self.cond(g)
90
+ x = self.conv_1(x * x_mask)
91
+ x = torch.relu(x)
92
+ x = self.norm_1(x)
93
+ x = self.drop(x)
94
+ x = self.conv_2(x * x_mask)
95
+ x = torch.relu(x)
96
+ x = self.norm_2(x)
97
+ x = self.drop(x)
98
+
99
+ output_probs = []
100
+ for dur in [dur_r, dur_hat]:
101
+ output_prob = self.forward_probability(x, x_mask, dur, g)
102
+ output_probs.append(output_prob)
103
+
104
+ return output_probs
105
+
106
+
107
+ class TransformerCouplingBlock(nn.Module):
108
+ def __init__(
109
+ self,
110
+ channels: int,
111
+ hidden_channels: int,
112
+ filter_channels: int,
113
+ n_heads: int,
114
+ n_layers: int,
115
+ kernel_size: int,
116
+ p_dropout: float,
117
+ n_flows: int = 4,
118
+ gin_channels: int = 0,
119
+ share_parameter: bool = False,
120
+ ) -> None:
121
+ super().__init__()
122
+ self.channels = channels
123
+ self.hidden_channels = hidden_channels
124
+ self.kernel_size = kernel_size
125
+ self.n_layers = n_layers
126
+ self.n_flows = n_flows
127
+ self.gin_channels = gin_channels
128
+
129
+ self.flows = nn.ModuleList()
130
+
131
+ self.wn = (
132
+ # attentions.FFT(
133
+ # hidden_channels,
134
+ # filter_channels,
135
+ # n_heads,
136
+ # n_layers,
137
+ # kernel_size,
138
+ # p_dropout,
139
+ # isflow=True,
140
+ # gin_channels=self.gin_channels,
141
+ # )
142
+ None
143
+ if share_parameter
144
+ else None
145
+ )
146
+
147
+ for i in range(n_flows):
148
+ self.flows.append(
149
+ modules.TransformerCouplingLayer(
150
+ channels,
151
+ hidden_channels,
152
+ kernel_size,
153
+ n_layers,
154
+ n_heads,
155
+ p_dropout,
156
+ filter_channels,
157
+ mean_only=True,
158
+ wn_sharing_parameter=self.wn,
159
+ gin_channels=self.gin_channels,
160
+ )
161
+ )
162
+ self.flows.append(modules.Flip())
163
+
164
+ def forward(
165
+ self,
166
+ x: torch.Tensor,
167
+ x_mask: torch.Tensor,
168
+ g: Optional[torch.Tensor] = None,
169
+ reverse: bool = False,
170
+ ) -> torch.Tensor:
171
+ if not reverse:
172
+ for flow in self.flows:
173
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
174
+ else:
175
+ for flow in reversed(self.flows):
176
+ x = flow(x, x_mask, g=g, reverse=reverse)
177
+ return x
178
+
179
+
180
+ class StochasticDurationPredictor(nn.Module):
181
+ def __init__(
182
+ self,
183
+ in_channels: int,
184
+ filter_channels: int,
185
+ kernel_size: int,
186
+ p_dropout: float,
187
+ n_flows: int = 4,
188
+ gin_channels: int = 0,
189
+ ) -> None:
190
+ super().__init__()
191
+ filter_channels = in_channels  # this override needs to be removed in a future version
192
+ self.in_channels = in_channels
193
+ self.filter_channels = filter_channels
194
+ self.kernel_size = kernel_size
195
+ self.p_dropout = p_dropout
196
+ self.n_flows = n_flows
197
+ self.gin_channels = gin_channels
198
+
199
+ self.log_flow = modules.Log()
200
+ self.flows = nn.ModuleList()
201
+ self.flows.append(modules.ElementwiseAffine(2))
202
+ for i in range(n_flows):
203
+ self.flows.append(
204
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
205
+ )
206
+ self.flows.append(modules.Flip())
207
+
208
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
209
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
210
+ self.post_convs = modules.DDSConv(
211
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
212
+ )
213
+ self.post_flows = nn.ModuleList()
214
+ self.post_flows.append(modules.ElementwiseAffine(2))
215
+ for i in range(4):
216
+ self.post_flows.append(
217
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
218
+ )
219
+ self.post_flows.append(modules.Flip())
220
+
221
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
222
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
223
+ self.convs = modules.DDSConv(
224
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
225
+ )
226
+ if gin_channels != 0:
227
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
228
+
229
+ def forward(
230
+ self,
231
+ x: torch.Tensor,
232
+ x_mask: torch.Tensor,
233
+ w: Optional[torch.Tensor] = None,
234
+ g: Optional[torch.Tensor] = None,
235
+ reverse: bool = False,
236
+ noise_scale: float = 1.0,
237
+ ) -> torch.Tensor:
238
+ x = torch.detach(x)
239
+ x = self.pre(x)
240
+ if g is not None:
241
+ g = torch.detach(g)
242
+ x = x + self.cond(g)
243
+ x = self.convs(x, x_mask)
244
+ x = self.proj(x) * x_mask
245
+
246
+ if not reverse:
247
+ flows = self.flows
248
+ assert w is not None
249
+
250
+ logdet_tot_q = 0
251
+ h_w = self.post_pre(w)
252
+ h_w = self.post_convs(h_w, x_mask)
253
+ h_w = self.post_proj(h_w) * x_mask
254
+ e_q = (
255
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
256
+ * x_mask
257
+ )
258
+ z_q = e_q
259
+ for flow in self.post_flows:
260
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
261
+ logdet_tot_q += logdet_q
262
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
263
+ u = torch.sigmoid(z_u) * x_mask
264
+ z0 = (w - u) * x_mask
265
+ logdet_tot_q += torch.sum(
266
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
267
+ )
268
+ logq = (
269
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
270
+ - logdet_tot_q
271
+ )
272
+
273
+ logdet_tot = 0
274
+ z0, logdet = self.log_flow(z0, x_mask)
275
+ logdet_tot += logdet
276
+ z = torch.cat([z0, z1], 1)
277
+ for flow in flows:
278
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
279
+ logdet_tot = logdet_tot + logdet
280
+ nll = (
281
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
282
+ - logdet_tot
283
+ )
284
+ return nll + logq # [b]
285
+ else:
286
+ flows = list(reversed(self.flows))
287
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
288
+ z = (
289
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
290
+ * noise_scale
291
+ )
292
+ for flow in flows:
293
+ z = flow(z, x_mask, g=x, reverse=reverse)
294
+ z0, z1 = torch.split(z, [1, 1], 1)
295
+ logw = z0
296
+ return logw
297
+
298
+
299
+ class DurationPredictor(nn.Module):
300
+ def __init__(
301
+ self,
302
+ in_channels: int,
303
+ filter_channels: int,
304
+ kernel_size: int,
305
+ p_dropout: float,
306
+ gin_channels: int = 0,
307
+ ) -> None:
308
+ super().__init__()
309
+
310
+ self.in_channels = in_channels
311
+ self.filter_channels = filter_channels
312
+ self.kernel_size = kernel_size
313
+ self.p_dropout = p_dropout
314
+ self.gin_channels = gin_channels
315
+
316
+ self.drop = nn.Dropout(p_dropout)
317
+ self.conv_1 = nn.Conv1d(
318
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
319
+ )
320
+ self.norm_1 = modules.LayerNorm(filter_channels)
321
+ self.conv_2 = nn.Conv1d(
322
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
323
+ )
324
+ self.norm_2 = modules.LayerNorm(filter_channels)
325
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
326
+
327
+ if gin_channels != 0:
328
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
329
+
330
+ def forward(
331
+ self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
332
+ ) -> torch.Tensor:
333
+ x = torch.detach(x)
334
+ if g is not None:
335
+ g = torch.detach(g)
336
+ x = x + self.cond(g)
337
+ x = self.conv_1(x * x_mask)
338
+ x = torch.relu(x)
339
+ x = self.norm_1(x)
340
+ x = self.drop(x)
341
+ x = self.conv_2(x * x_mask)
342
+ x = torch.relu(x)
343
+ x = self.norm_2(x)
344
+ x = self.drop(x)
345
+ x = self.proj(x * x_mask)
346
+ return x * x_mask
347
+
348
+
349
+ class TextEncoder(nn.Module):
350
+ def __init__(
351
+ self,
352
+ n_vocab: int,
353
+ out_channels: int,
354
+ hidden_channels: int,
355
+ filter_channels: int,
356
+ n_heads: int,
357
+ n_layers: int,
358
+ kernel_size: int,
359
+ p_dropout: float,
360
+ n_speakers: int,
361
+ gin_channels: int = 0,
362
+ ) -> None:
363
+ super().__init__()
364
+ self.n_vocab = n_vocab
365
+ self.out_channels = out_channels
366
+ self.hidden_channels = hidden_channels
367
+ self.filter_channels = filter_channels
368
+ self.n_heads = n_heads
369
+ self.n_layers = n_layers
370
+ self.kernel_size = kernel_size
371
+ self.p_dropout = p_dropout
372
+ self.gin_channels = gin_channels
373
+ self.emb = nn.Embedding(len(SYMBOLS), hidden_channels)
374
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
375
+ self.tone_emb = nn.Embedding(NUM_TONES, hidden_channels)
376
+ nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
377
+ self.language_emb = nn.Embedding(NUM_LANGUAGES, hidden_channels)
378
+ nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
379
+ self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
380
+ self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
381
+ self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
382
+ self.style_proj = nn.Linear(256, hidden_channels)
383
+
384
+ self.encoder = attentions.Encoder(
385
+ hidden_channels,
386
+ filter_channels,
387
+ n_heads,
388
+ n_layers,
389
+ kernel_size,
390
+ p_dropout,
391
+ gin_channels=self.gin_channels,
392
+ )
393
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
394
+
395
+ def forward(
396
+ self,
397
+ x: torch.Tensor,
398
+ x_lengths: torch.Tensor,
399
+ tone: torch.Tensor,
400
+ language: torch.Tensor,
401
+ bert: torch.Tensor,
402
+ ja_bert: torch.Tensor,
403
+ en_bert: torch.Tensor,
404
+ style_vec: torch.Tensor,
405
+ sid: torch.Tensor,
406
+ g: Optional[torch.Tensor] = None,
407
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
408
+ bert_emb = self.bert_proj(bert).transpose(1, 2)
409
+ ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
410
+ en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
411
+ style_emb = self.style_proj(style_vec.unsqueeze(1))
412
+
413
+ x = (
414
+ self.emb(x)
415
+ + self.tone_emb(tone)
416
+ + self.language_emb(language)
417
+ + bert_emb
418
+ + ja_bert_emb
419
+ + en_bert_emb
420
+ + style_emb
421
+ ) * math.sqrt(
422
+ self.hidden_channels
423
+ ) # [b, t, h]
424
+ x = torch.transpose(x, 1, -1) # [b, h, t]
425
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
426
+ x.dtype
427
+ )
428
+
429
+ x = self.encoder(x * x_mask, x_mask, g=g)
430
+ stats = self.proj(x) * x_mask
431
+
432
+ m, logs = torch.split(stats, self.out_channels, dim=1)
433
+ return x, m, logs, x_mask
434
+
435
+
436
+ class ResidualCouplingBlock(nn.Module):
437
+ def __init__(
438
+ self,
439
+ channels: int,
440
+ hidden_channels: int,
441
+ kernel_size: int,
442
+ dilation_rate: int,
443
+ n_layers: int,
444
+ n_flows: int = 4,
445
+ gin_channels: int = 0,
446
+ ) -> None:
447
+ super().__init__()
448
+ self.channels = channels
449
+ self.hidden_channels = hidden_channels
450
+ self.kernel_size = kernel_size
451
+ self.dilation_rate = dilation_rate
452
+ self.n_layers = n_layers
453
+ self.n_flows = n_flows
454
+ self.gin_channels = gin_channels
455
+
456
+ self.flows = nn.ModuleList()
457
+ for i in range(n_flows):
458
+ self.flows.append(
459
+ modules.ResidualCouplingLayer(
460
+ channels,
461
+ hidden_channels,
462
+ kernel_size,
463
+ dilation_rate,
464
+ n_layers,
465
+ gin_channels=gin_channels,
466
+ mean_only=True,
467
+ )
468
+ )
469
+ self.flows.append(modules.Flip())
470
+
471
+ def forward(
472
+ self,
473
+ x: torch.Tensor,
474
+ x_mask: torch.Tensor,
475
+ g: Optional[torch.Tensor] = None,
476
+ reverse: bool = False,
477
+ ) -> torch.Tensor:
478
+ if not reverse:
479
+ for flow in self.flows:
480
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
481
+ else:
482
+ for flow in reversed(self.flows):
483
+ x = flow(x, x_mask, g=g, reverse=reverse)
484
+ return x
485
+
486
+
487
+ class PosteriorEncoder(nn.Module):
488
+ def __init__(
489
+ self,
490
+ in_channels: int,
491
+ out_channels: int,
492
+ hidden_channels: int,
493
+ kernel_size: int,
494
+ dilation_rate: int,
495
+ n_layers: int,
496
+ gin_channels: int = 0,
497
+ ) -> None:
498
+ super().__init__()
499
+ self.in_channels = in_channels
500
+ self.out_channels = out_channels
501
+ self.hidden_channels = hidden_channels
502
+ self.kernel_size = kernel_size
503
+ self.dilation_rate = dilation_rate
504
+ self.n_layers = n_layers
505
+ self.gin_channels = gin_channels
506
+
507
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
508
+ self.enc = modules.WN(
509
+ hidden_channels,
510
+ kernel_size,
511
+ dilation_rate,
512
+ n_layers,
513
+ gin_channels=gin_channels,
514
+ )
515
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
516
+
517
+ def forward(
518
+ self,
519
+ x: torch.Tensor,
520
+ x_lengths: torch.Tensor,
521
+ g: Optional[torch.Tensor] = None,
522
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
523
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
524
+ x.dtype
525
+ )
526
+ x = self.pre(x) * x_mask
527
+ x = self.enc(x, x_mask, g=g)
528
+ stats = self.proj(x) * x_mask
529
+ m, logs = torch.split(stats, self.out_channels, dim=1)
530
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
531
+ return z, m, logs, x_mask
532
+
533
+
534
+ class Generator(torch.nn.Module):
535
+ def __init__(
536
+ self,
537
+ initial_channel: int,
538
+ resblock_str: str,
539
+ resblock_kernel_sizes: list[int],
540
+ resblock_dilation_sizes: list[list[int]],
541
+ upsample_rates: list[int],
542
+ upsample_initial_channel: int,
543
+ upsample_kernel_sizes: list[int],
544
+ gin_channels: int = 0,
545
+ ) -> None:
546
+ super(Generator, self).__init__()
547
+ self.num_kernels = len(resblock_kernel_sizes)
548
+ self.num_upsamples = len(upsample_rates)
549
+ self.conv_pre = Conv1d(
550
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
551
+ )
552
+ resblock = modules.ResBlock1 if resblock_str == "1" else modules.ResBlock2
553
+
554
+ self.ups = nn.ModuleList()
555
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
556
+ self.ups.append(
557
+ weight_norm(
558
+ ConvTranspose1d(
559
+ upsample_initial_channel // (2**i),
560
+ upsample_initial_channel // (2 ** (i + 1)),
561
+ k,
562
+ u,
563
+ padding=(k - u) // 2,
564
+ )
565
+ )
566
+ )
567
+
568
+ self.resblocks = nn.ModuleList()
569
+ ch = None
570
+ for i in range(len(self.ups)):
571
+ ch = upsample_initial_channel // (2 ** (i + 1))
572
+ for j, (k, d) in enumerate(
573
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
574
+ ):
575
+ self.resblocks.append(resblock(ch, k, d)) # type: ignore
576
+
577
+ assert ch is not None
578
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
579
+ self.ups.apply(commons.init_weights)
580
+
581
+ if gin_channels != 0:
582
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
583
+
584
+ def forward(
585
+ self, x: torch.Tensor, g: Optional[torch.Tensor] = None
586
+ ) -> torch.Tensor:
587
+ x = self.conv_pre(x)
588
+ if g is not None:
589
+ x = x + self.cond(g)
590
+
591
+ for i in range(self.num_upsamples):
592
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
593
+ x = self.ups[i](x)
594
+ xs = None
595
+ for j in range(self.num_kernels):
596
+ if xs is None:
597
+ xs = self.resblocks[i * self.num_kernels + j](x)
598
+ else:
599
+ xs += self.resblocks[i * self.num_kernels + j](x)
600
+ assert xs is not None
601
+ x = xs / self.num_kernels
602
+ x = F.leaky_relu(x)
603
+ x = self.conv_post(x)
604
+ x = torch.tanh(x)
605
+
606
+ return x
607
+
608
+ def remove_weight_norm(self) -> None:
609
+ print("Removing weight norm...")
610
+ for layer in self.ups:
611
+ remove_weight_norm(layer)
612
+ for layer in self.resblocks:
613
+ layer.remove_weight_norm()
614
+
615
+
616
+ class DiscriminatorP(torch.nn.Module):
617
+ def __init__(
618
+ self,
619
+ period: int,
620
+ kernel_size: int = 5,
621
+ stride: int = 3,
622
+ use_spectral_norm: bool = False,
623
+ ) -> None:
624
+ super(DiscriminatorP, self).__init__()
625
+ self.period = period
626
+ self.use_spectral_norm = use_spectral_norm
627
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
628
+ self.convs = nn.ModuleList(
629
+ [
630
+ norm_f(
631
+ Conv2d(
632
+ 1,
633
+ 32,
634
+ (kernel_size, 1),
635
+ (stride, 1),
636
+ padding=(commons.get_padding(kernel_size, 1), 0),
637
+ )
638
+ ),
639
+ norm_f(
640
+ Conv2d(
641
+ 32,
642
+ 128,
643
+ (kernel_size, 1),
644
+ (stride, 1),
645
+ padding=(commons.get_padding(kernel_size, 1), 0),
646
+ )
647
+ ),
648
+ norm_f(
649
+ Conv2d(
650
+ 128,
651
+ 512,
652
+ (kernel_size, 1),
653
+ (stride, 1),
654
+ padding=(commons.get_padding(kernel_size, 1), 0),
655
+ )
656
+ ),
657
+ norm_f(
658
+ Conv2d(
659
+ 512,
660
+ 1024,
661
+ (kernel_size, 1),
662
+ (stride, 1),
663
+ padding=(commons.get_padding(kernel_size, 1), 0),
664
+ )
665
+ ),
666
+ norm_f(
667
+ Conv2d(
668
+ 1024,
669
+ 1024,
670
+ (kernel_size, 1),
671
+ 1,
672
+ padding=(commons.get_padding(kernel_size, 1), 0),
673
+ )
674
+ ),
675
+ ]
676
+ )
677
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
678
+
679
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
680
+ fmap = []
681
+
682
+ # 1d to 2d
683
+ b, c, t = x.shape
684
+ if t % self.period != 0: # pad first
685
+ n_pad = self.period - (t % self.period)
686
+ x = F.pad(x, (0, n_pad), "reflect")
687
+ t = t + n_pad
688
+ x = x.view(b, c, t // self.period, self.period)
689
+
690
+ for layer in self.convs:
691
+ x = layer(x)
692
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
693
+ fmap.append(x)
694
+ x = self.conv_post(x)
695
+ fmap.append(x)
696
+ x = torch.flatten(x, 1, -1)
697
+
698
+ return x, fmap
699
+
700
+
701
+ class DiscriminatorS(torch.nn.Module):
702
+ def __init__(self, use_spectral_norm: bool = False) -> None:
703
+ super(DiscriminatorS, self).__init__()
704
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
705
+ self.convs = nn.ModuleList(
706
+ [
707
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
708
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
709
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
710
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
711
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
712
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
713
+ ]
714
+ )
715
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
716
+
717
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
718
+ fmap = []
719
+
720
+ for layer in self.convs:
721
+ x = layer(x)
722
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
723
+ fmap.append(x)
724
+ x = self.conv_post(x)
725
+ fmap.append(x)
726
+ x = torch.flatten(x, 1, -1)
727
+
728
+ return x, fmap
729
+
730
+
731
+ class MultiPeriodDiscriminator(torch.nn.Module):
732
+ def __init__(self, use_spectral_norm: bool = False) -> None:
733
+ super(MultiPeriodDiscriminator, self).__init__()
734
+ periods = [2, 3, 5, 7, 11]
735
+
736
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
737
+ discs = discs + [
738
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
739
+ ]
740
+ self.discriminators = nn.ModuleList(discs)
741
+
742
+ def forward(
743
+ self,
744
+ y: torch.Tensor,
745
+ y_hat: torch.Tensor,
746
+ ) -> tuple[
747
+ list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]
748
+ ]:
749
+ y_d_rs = []
750
+ y_d_gs = []
751
+ fmap_rs = []
752
+ fmap_gs = []
753
+ for i, d in enumerate(self.discriminators):
754
+ y_d_r, fmap_r = d(y)
755
+ y_d_g, fmap_g = d(y_hat)
756
+ y_d_rs.append(y_d_r)
757
+ y_d_gs.append(y_d_g)
758
+ fmap_rs.append(fmap_r)
759
+ fmap_gs.append(fmap_g)
760
+
761
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
762
+
763
+
764
+ class ReferenceEncoder(nn.Module):
765
+ """
766
+ inputs --- [N, Ty/r, n_mels*r] mels
767
+ outputs --- [N, ref_enc_gru_size]
768
+ """
769
+
770
+ def __init__(self, spec_channels: int, gin_channels: int = 0) -> None:
771
+ super().__init__()
772
+ self.spec_channels = spec_channels
773
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
774
+ K = len(ref_enc_filters)
775
+ filters = [1] + ref_enc_filters
776
+ convs = [
777
+ weight_norm(
778
+ nn.Conv2d(
779
+ in_channels=filters[i],
780
+ out_channels=filters[i + 1],
781
+ kernel_size=(3, 3),
782
+ stride=(2, 2),
783
+ padding=(1, 1),
784
+ )
785
+ )
786
+ for i in range(K)
787
+ ]
788
+ self.convs = nn.ModuleList(convs)
789
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
790
+
791
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
792
+ self.gru = nn.GRU(
793
+ input_size=ref_enc_filters[-1] * out_channels,
794
+ hidden_size=256 // 2,
795
+ batch_first=True,
796
+ )
797
+ self.proj = nn.Linear(128, gin_channels)
798
+
799
+ def forward(
800
+ self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None
801
+ ) -> torch.Tensor:
802
+ N = inputs.size(0)
803
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
804
+ for conv in self.convs:
805
+ out = conv(out)
806
+ # out = wn(out)
807
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
808
+
809
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
810
+ T = out.size(1)
811
+ N = out.size(0)
812
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
813
+
814
+ self.gru.flatten_parameters()
815
+ memory, out = self.gru(out) # out --- [1, N, 128]
816
+
817
+ return self.proj(out.squeeze(0))
818
+
819
+ def calculate_channels(
820
+ self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int
821
+ ) -> int:
822
+ for i in range(n_convs):
823
+ L = (L - kernel_size + 2 * pad) // stride + 1
824
+ return L
825
+
826
+
827
+ class SynthesizerTrn(nn.Module):
828
+ """
829
+ Synthesizer for Training
830
+ """
831
+
832
+ def __init__(
833
+ self,
834
+ n_vocab: int,
835
+ spec_channels: int,
836
+ segment_size: int,
837
+ inter_channels: int,
838
+ hidden_channels: int,
839
+ filter_channels: int,
840
+ n_heads: int,
841
+ n_layers: int,
842
+ kernel_size: int,
843
+ p_dropout: float,
844
+ resblock: str,
845
+ resblock_kernel_sizes: list[int],
846
+ resblock_dilation_sizes: list[list[int]],
847
+ upsample_rates: list[int],
848
+ upsample_initial_channel: int,
849
+ upsample_kernel_sizes: list[int],
850
+ n_speakers: int = 256,
851
+ gin_channels: int = 256,
852
+ use_sdp: bool = True,
853
+ n_flow_layer: int = 4,
854
+ n_layers_trans_flow: int = 4,
855
+ flow_share_parameter: bool = False,
856
+ use_transformer_flow: bool = True,
857
+ **kwargs: Any,
858
+ ) -> None:
859
+ super().__init__()
860
+ self.n_vocab = n_vocab
861
+ self.spec_channels = spec_channels
862
+ self.inter_channels = inter_channels
863
+ self.hidden_channels = hidden_channels
864
+ self.filter_channels = filter_channels
865
+ self.n_heads = n_heads
866
+ self.n_layers = n_layers
867
+ self.kernel_size = kernel_size
868
+ self.p_dropout = p_dropout
869
+ self.resblock = resblock
870
+ self.resblock_kernel_sizes = resblock_kernel_sizes
871
+ self.resblock_dilation_sizes = resblock_dilation_sizes
872
+ self.upsample_rates = upsample_rates
873
+ self.upsample_initial_channel = upsample_initial_channel
874
+ self.upsample_kernel_sizes = upsample_kernel_sizes
875
+ self.segment_size = segment_size
876
+ self.n_speakers = n_speakers
877
+ self.gin_channels = gin_channels
878
+ self.n_layers_trans_flow = n_layers_trans_flow
879
+ self.use_spk_conditioned_encoder = kwargs.get(
880
+ "use_spk_conditioned_encoder", True
881
+ )
882
+ self.use_sdp = use_sdp
883
+ self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
884
+ self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
885
+ self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
886
+ self.current_mas_noise_scale = self.mas_noise_scale_initial
887
+ if self.use_spk_conditioned_encoder and gin_channels > 0:
888
+ self.enc_gin_channels = gin_channels
889
+ self.enc_p = TextEncoder(
890
+ n_vocab,
891
+ inter_channels,
892
+ hidden_channels,
893
+ filter_channels,
894
+ n_heads,
895
+ n_layers,
896
+ kernel_size,
897
+ p_dropout,
898
+ self.n_speakers,
899
+ gin_channels=self.enc_gin_channels,
900
+ )
901
+ self.dec = Generator(
902
+ inter_channels,
903
+ resblock,
904
+ resblock_kernel_sizes,
905
+ resblock_dilation_sizes,
906
+ upsample_rates,
907
+ upsample_initial_channel,
908
+ upsample_kernel_sizes,
909
+ gin_channels=gin_channels,
910
+ )
911
+ self.enc_q = PosteriorEncoder(
912
+ spec_channels,
913
+ inter_channels,
914
+ hidden_channels,
915
+ 5,
916
+ 1,
917
+ 16,
918
+ gin_channels=gin_channels,
919
+ )
920
+ if use_transformer_flow:
921
+ self.flow = TransformerCouplingBlock(
922
+ inter_channels,
923
+ hidden_channels,
924
+ filter_channels,
925
+ n_heads,
926
+ n_layers_trans_flow,
927
+ 5,
928
+ p_dropout,
929
+ n_flow_layer,
930
+ gin_channels=gin_channels,
931
+ share_parameter=flow_share_parameter,
932
+ )
933
+ else:
934
+ self.flow = ResidualCouplingBlock(
935
+ inter_channels,
936
+ hidden_channels,
937
+ 5,
938
+ 1,
939
+ n_flow_layer,
940
+ gin_channels=gin_channels,
941
+ )
942
+ self.sdp = StochasticDurationPredictor(
943
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
944
+ )
945
+ self.dp = DurationPredictor(
946
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
947
+ )
948
+
949
+ if n_speakers >= 1:
950
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
951
+ else:
952
+ self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
953
+
954
+ def forward(
955
+ self,
956
+ x: torch.Tensor,
957
+ x_lengths: torch.Tensor,
958
+ y: torch.Tensor,
959
+ y_lengths: torch.Tensor,
960
+ sid: torch.Tensor,
961
+ tone: torch.Tensor,
962
+ language: torch.Tensor,
963
+ bert: torch.Tensor,
964
+ ja_bert: torch.Tensor,
965
+ en_bert: torch.Tensor,
966
+ style_vec: torch.Tensor,
967
+ ) -> tuple[
968
+ torch.Tensor,
969
+ torch.Tensor,
970
+ torch.Tensor,
971
+ torch.Tensor,
972
+ torch.Tensor,
973
+ torch.Tensor,
974
+ tuple[torch.Tensor, ...],
975
+ tuple[torch.Tensor, ...],
976
+ ]:
977
+ if self.n_speakers > 0:
978
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
979
+ else:
980
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
981
+ x, m_p, logs_p, x_mask = self.enc_p(
982
+ x, x_lengths, tone, language, bert, ja_bert, en_bert, style_vec, sid, g=g
983
+ )
984
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
985
+ z_p = self.flow(z, y_mask, g=g)
986
+
987
+ with torch.no_grad():
988
+ # negative cross-entropy
989
+ s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
990
+ neg_cent1 = torch.sum(
991
+ -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
992
+ ) # [b, 1, t_s]
993
+ neg_cent2 = torch.matmul(
994
+ -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
995
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
996
+ neg_cent3 = torch.matmul(
997
+ z_p.transpose(1, 2), (m_p * s_p_sq_r)
998
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
999
+ neg_cent4 = torch.sum(
1000
+ -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
1001
+ ) # [b, 1, t_s]
1002
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
1003
+ if self.use_noise_scaled_mas:
1004
+ epsilon = (
1005
+ torch.std(neg_cent)
1006
+ * torch.randn_like(neg_cent)
1007
+ * self.current_mas_noise_scale
1008
+ )
1009
+ neg_cent = neg_cent + epsilon
1010
+
1011
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
1012
+ attn = (
1013
+ monotonic_alignment.maximum_path(neg_cent, attn_mask.squeeze(1))
1014
+ .unsqueeze(1)
1015
+ .detach()
1016
+ )
1017
+
1018
+ w = attn.sum(2)
1019
+
1020
+ l_length_sdp = self.sdp(x, x_mask, w, g=g)
1021
+ l_length_sdp = l_length_sdp / torch.sum(x_mask)
1022
+
1023
+ logw_ = torch.log(w + 1e-6) * x_mask
1024
+ logw = self.dp(x, x_mask, g=g)
1025
+ # logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
1026
+ l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
1027
+ x_mask
1028
+ ) # for averaging
1029
+ # l_length_sdp += torch.sum((logw_sdp - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
1030
+
1031
+ l_length = l_length_dp + l_length_sdp
1032
+
1033
+ # expand prior
1034
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
1035
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
1036
+
1037
+ z_slice, ids_slice = commons.rand_slice_segments(
1038
+ z, y_lengths, self.segment_size
1039
+ )
1040
+ o = self.dec(z_slice, g=g)
1041
+ return (
1042
+ o,
1043
+ l_length,
1044
+ attn,
1045
+ ids_slice,
1046
+ x_mask,
1047
+ y_mask,
1048
+ (z, z_p, m_p, logs_p, m_q, logs_q),
1049
+ (x, logw, logw_),
1050
+ )
1051
+
1052
+ def infer(
1053
+ self,
1054
+ x: torch.Tensor,
1055
+ x_lengths: torch.Tensor,
1056
+ sid: torch.Tensor,
1057
+ tone: torch.Tensor,
1058
+ language: torch.Tensor,
1059
+ bert: torch.Tensor,
1060
+ ja_bert: torch.Tensor,
1061
+ en_bert: torch.Tensor,
1062
+ style_vec: torch.Tensor,
1063
+ noise_scale: float = 0.667,
1064
+ length_scale: float = 1.0,
1065
+ noise_scale_w: float = 0.8,
1066
+ max_len: Optional[int] = None,
1067
+ sdp_ratio: float = 0.0,
1068
+ y: Optional[torch.Tensor] = None,
1069
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, ...]]:
1070
+ # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
1071
+ # g = self.gst(y)
1072
+ if self.n_speakers > 0:
1073
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
1074
+ else:
1075
+ assert y is not None
1076
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
1077
+ x, m_p, logs_p, x_mask = self.enc_p(
1078
+ x, x_lengths, tone, language, bert, ja_bert, en_bert, style_vec, sid, g=g
1079
+ )
1080
+ logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
1081
+ sdp_ratio
1082
+ ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
1083
+ w = torch.exp(logw) * x_mask * length_scale
1084
+ w_ceil = torch.ceil(w)
1085
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
1086
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
1087
+ x_mask.dtype
1088
+ )
1089
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
1090
+ attn = commons.generate_path(w_ceil, attn_mask)
1091
+
1092
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
1093
+ 1, 2
1094
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1095
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
1096
+ 1, 2
1097
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1098
+
1099
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1100
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
1101
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
1102
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
style_bert_vits2/models/models_jp_extra.py ADDED
@@ -0,0 +1,1157 @@
1
+ import math
2
+ from typing import Any, Optional
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import Conv1d, Conv2d, ConvTranspose1d
7
+ from torch.nn import functional as F
8
+ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
9
+
10
+ from style_bert_vits2.models import attentions, commons, modules, monotonic_alignment
11
+ from style_bert_vits2.nlp.symbols import NUM_LANGUAGES, NUM_TONES, SYMBOLS
12
+
13
+
14
+ class DurationDiscriminator(nn.Module): # vits2
15
+ def __init__(
16
+ self,
17
+ in_channels: int,
18
+ filter_channels: int,
19
+ kernel_size: int,
20
+ p_dropout: float,
21
+ gin_channels: int = 0,
22
+ ) -> None:
23
+ super().__init__()
24
+
25
+ self.in_channels = in_channels
26
+ self.filter_channels = filter_channels
27
+ self.kernel_size = kernel_size
28
+ self.p_dropout = p_dropout
29
+ self.gin_channels = gin_channels
30
+
31
+ self.drop = nn.Dropout(p_dropout)
32
+ self.conv_1 = nn.Conv1d(
33
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
34
+ )
35
+ self.norm_1 = modules.LayerNorm(filter_channels)
36
+ self.conv_2 = nn.Conv1d(
37
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
38
+ )
39
+ self.norm_2 = modules.LayerNorm(filter_channels)
40
+ self.dur_proj = nn.Conv1d(1, filter_channels, 1)
41
+
42
+ self.LSTM = nn.LSTM(
43
+ 2 * filter_channels, filter_channels, batch_first=True, bidirectional=True
44
+ )
45
+
46
+ if gin_channels != 0:
47
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
48
+
49
+ self.output_layer = nn.Sequential(
50
+ nn.Linear(2 * filter_channels, 1), nn.Sigmoid()
51
+ )
52
+
53
+ def forward_probability(self, x: torch.Tensor, dur: torch.Tensor) -> torch.Tensor:
54
+ dur = self.dur_proj(dur)
55
+ x = torch.cat([x, dur], dim=1)
56
+ x = x.transpose(1, 2)
57
+ x, _ = self.LSTM(x)
58
+ output_prob = self.output_layer(x)
59
+ return output_prob
60
+
61
+ def forward(
62
+ self,
63
+ x: torch.Tensor,
64
+ x_mask: torch.Tensor,
65
+ dur_r: torch.Tensor,
66
+ dur_hat: torch.Tensor,
67
+ g: Optional[torch.Tensor] = None,
68
+ ) -> list[torch.Tensor]:
69
+ x = torch.detach(x)
70
+ if g is not None:
71
+ g = torch.detach(g)
72
+ x = x + self.cond(g)
73
+ x = self.conv_1(x * x_mask)
74
+ x = torch.relu(x)
75
+ x = self.norm_1(x)
76
+ x = self.drop(x)
77
+ x = self.conv_2(x * x_mask)
78
+ x = torch.relu(x)
79
+ x = self.norm_2(x)
80
+ x = self.drop(x)
81
+
82
+ output_probs = []
83
+ for dur in [dur_r, dur_hat]:
84
+ output_prob = self.forward_probability(x, dur)
85
+ output_probs.append(output_prob)
86
+
87
+ return output_probs
88
+
89
+
90
+ class TransformerCouplingBlock(nn.Module):
91
+ def __init__(
92
+ self,
93
+ channels: int,
94
+ hidden_channels: int,
95
+ filter_channels: int,
96
+ n_heads: int,
97
+ n_layers: int,
98
+ kernel_size: int,
99
+ p_dropout: float,
100
+ n_flows: int = 4,
101
+ gin_channels: int = 0,
102
+ share_parameter: bool = False,
103
+ ) -> None:
104
+ super().__init__()
105
+ self.channels = channels
106
+ self.hidden_channels = hidden_channels
107
+ self.kernel_size = kernel_size
108
+ self.n_layers = n_layers
109
+ self.n_flows = n_flows
110
+ self.gin_channels = gin_channels
111
+
112
+ self.flows = nn.ModuleList()
113
+
114
+ self.wn = (
115
+ # attentions.FFT(
116
+ # hidden_channels,
117
+ # filter_channels,
118
+ # n_heads,
119
+ # n_layers,
120
+ # kernel_size,
121
+ # p_dropout,
122
+ # isflow=True,
123
+ # gin_channels=self.gin_channels,
124
+ # )
125
+ None
126
+ if share_parameter
127
+ else None
128
+ )
129
+
130
+ for i in range(n_flows):
131
+ self.flows.append(
132
+ modules.TransformerCouplingLayer(
133
+ channels,
134
+ hidden_channels,
135
+ kernel_size,
136
+ n_layers,
137
+ n_heads,
138
+ p_dropout,
139
+ filter_channels,
140
+ mean_only=True,
141
+ wn_sharing_parameter=self.wn,
142
+ gin_channels=self.gin_channels,
143
+ )
144
+ )
145
+ self.flows.append(modules.Flip())
146
+
147
+ def forward(
148
+ self,
149
+ x: torch.Tensor,
150
+ x_mask: torch.Tensor,
151
+ g: Optional[torch.Tensor] = None,
152
+ reverse: bool = False,
153
+ ) -> torch.Tensor:
154
+ if not reverse:
155
+ for flow in self.flows:
156
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
157
+ else:
158
+ for flow in reversed(self.flows):
159
+ x = flow(x, x_mask, g=g, reverse=reverse)
160
+ return x
161
+
162
+
163
+ class StochasticDurationPredictor(nn.Module):
164
+ def __init__(
165
+ self,
166
+ in_channels: int,
167
+ filter_channels: int,
168
+ kernel_size: int,
169
+ p_dropout: float,
170
+ n_flows: int = 4,
171
+ gin_channels: int = 0,
172
+ ) -> None:
173
+ super().__init__()
174
+ filter_channels = in_channels # this override needs to be removed in a future version
175
+ self.in_channels = in_channels
176
+ self.filter_channels = filter_channels
177
+ self.kernel_size = kernel_size
178
+ self.p_dropout = p_dropout
179
+ self.n_flows = n_flows
180
+ self.gin_channels = gin_channels
181
+
182
+ self.log_flow = modules.Log()
183
+ self.flows = nn.ModuleList()
184
+ self.flows.append(modules.ElementwiseAffine(2))
185
+ for i in range(n_flows):
186
+ self.flows.append(
187
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
188
+ )
189
+ self.flows.append(modules.Flip())
190
+
191
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
192
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
193
+ self.post_convs = modules.DDSConv(
194
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
195
+ )
196
+ self.post_flows = nn.ModuleList()
197
+ self.post_flows.append(modules.ElementwiseAffine(2))
198
+ for i in range(4):
199
+ self.post_flows.append(
200
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
201
+ )
202
+ self.post_flows.append(modules.Flip())
203
+
204
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
205
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
206
+ self.convs = modules.DDSConv(
207
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
208
+ )
209
+ if gin_channels != 0:
210
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
211
+
212
+ def forward(
213
+ self,
214
+ x: torch.Tensor,
215
+ x_mask: torch.Tensor,
216
+ w: Optional[torch.Tensor] = None,
217
+ g: Optional[torch.Tensor] = None,
218
+ reverse: bool = False,
219
+ noise_scale: float = 1.0,
220
+ ) -> torch.Tensor:
221
+ x = torch.detach(x)
222
+ x = self.pre(x)
223
+ if g is not None:
224
+ g = torch.detach(g)
225
+ x = x + self.cond(g)
226
+ x = self.convs(x, x_mask)
227
+ x = self.proj(x) * x_mask
228
+
229
+ if not reverse:
230
+ flows = self.flows
231
+ assert w is not None
232
+
233
+ logdet_tot_q = 0
234
+ h_w = self.post_pre(w)
235
+ h_w = self.post_convs(h_w, x_mask)
236
+ h_w = self.post_proj(h_w) * x_mask
237
+ e_q = (
238
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
239
+ * x_mask
240
+ )
241
+ z_q = e_q
242
+ for flow in self.post_flows:
243
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
244
+ logdet_tot_q += logdet_q
245
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
246
+ u = torch.sigmoid(z_u) * x_mask
247
+ z0 = (w - u) * x_mask
248
+ logdet_tot_q += torch.sum(
249
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
250
+ )
251
+ logq = (
252
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
253
+ - logdet_tot_q
254
+ )
255
+
256
+ logdet_tot = 0
257
+ z0, logdet = self.log_flow(z0, x_mask)
258
+ logdet_tot += logdet
259
+ z = torch.cat([z0, z1], 1)
260
+ for flow in flows:
261
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
262
+ logdet_tot = logdet_tot + logdet
263
+ nll = (
264
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
265
+ - logdet_tot
266
+ )
267
+ return nll + logq # [b]
268
+ else:
269
+ flows = list(reversed(self.flows))
270
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
271
+ z = (
272
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
273
+ * noise_scale
274
+ )
275
+ for flow in flows:
276
+ z = flow(z, x_mask, g=x, reverse=reverse)
277
+ z0, z1 = torch.split(z, [1, 1], 1)
278
+ logw = z0
279
+ return logw
280
+
281
+
282
+ class DurationPredictor(nn.Module):
283
+ def __init__(
284
+ self,
285
+ in_channels: int,
286
+ filter_channels: int,
287
+ kernel_size: int,
288
+ p_dropout: float,
289
+ gin_channels: int = 0,
290
+ ) -> None:
291
+ super().__init__()
292
+
293
+ self.in_channels = in_channels
294
+ self.filter_channels = filter_channels
295
+ self.kernel_size = kernel_size
296
+ self.p_dropout = p_dropout
297
+ self.gin_channels = gin_channels
298
+
299
+ self.drop = nn.Dropout(p_dropout)
300
+ self.conv_1 = nn.Conv1d(
301
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
302
+ )
303
+ self.norm_1 = modules.LayerNorm(filter_channels)
304
+ self.conv_2 = nn.Conv1d(
305
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
306
+ )
307
+ self.norm_2 = modules.LayerNorm(filter_channels)
308
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
309
+
310
+ if gin_channels != 0:
311
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
312
+
313
+ def forward(
314
+ self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
315
+ ) -> torch.Tensor:
316
+ x = torch.detach(x)
317
+ if g is not None:
318
+ g = torch.detach(g)
319
+ x = x + self.cond(g)
320
+ x = self.conv_1(x * x_mask)
321
+ x = torch.relu(x)
322
+ x = self.norm_1(x)
323
+ x = self.drop(x)
324
+ x = self.conv_2(x * x_mask)
325
+ x = torch.relu(x)
326
+ x = self.norm_2(x)
327
+ x = self.drop(x)
328
+ x = self.proj(x * x_mask)
329
+ return x * x_mask
330
+
331
+
332
+ class Bottleneck(nn.Sequential):
333
+ def __init__(self, in_dim: int, hidden_dim: int) -> None:
334
+ c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
335
+ c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
336
+ super().__init__(c_fc1, c_fc2)
337
+
338
+
339
+ class Block(nn.Module):
340
+ def __init__(self, in_dim: int, hidden_dim: int) -> None:
341
+ super().__init__()
342
+ self.norm = nn.LayerNorm(in_dim)
343
+ self.mlp = MLP(in_dim, hidden_dim)
344
+
345
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
346
+ x = x + self.mlp(self.norm(x))
347
+ return x
348
+
349
+
350
+ class MLP(nn.Module):
351
+ def __init__(self, in_dim: int, hidden_dim: int) -> None:
352
+ super().__init__()
353
+ self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
354
+ self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
355
+ self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False)
356
+
357
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
358
+ x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
359
+ x = self.c_proj(x)
360
+ return x
361
+
362
+
363
+ class TextEncoder(nn.Module):
364
+ def __init__(
365
+ self,
366
+ n_vocab: int,
367
+ out_channels: int,
368
+ hidden_channels: int,
369
+ filter_channels: int,
370
+ n_heads: int,
371
+ n_layers: int,
372
+ kernel_size: int,
373
+ p_dropout: float,
374
+ gin_channels: int = 0,
375
+ ) -> None:
376
+ super().__init__()
377
+ self.n_vocab = n_vocab
378
+ self.out_channels = out_channels
379
+ self.hidden_channels = hidden_channels
380
+ self.filter_channels = filter_channels
381
+ self.n_heads = n_heads
382
+ self.n_layers = n_layers
383
+ self.kernel_size = kernel_size
384
+ self.p_dropout = p_dropout
385
+ self.gin_channels = gin_channels
386
+ self.emb = nn.Embedding(len(SYMBOLS), hidden_channels)
387
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
388
+ self.tone_emb = nn.Embedding(NUM_TONES, hidden_channels)
389
+ nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
390
+ self.language_emb = nn.Embedding(NUM_LANGUAGES, hidden_channels)
391
+ nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
392
+ self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
393
+
394
+ # Remove emo_vq since it's not working well.
395
+ self.style_proj = nn.Linear(256, hidden_channels)
396
+
397
+ self.encoder = attentions.Encoder(
398
+ hidden_channels,
399
+ filter_channels,
400
+ n_heads,
401
+ n_layers,
402
+ kernel_size,
403
+ p_dropout,
404
+ gin_channels=self.gin_channels,
405
+ )
406
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
407
+
408
+ def forward(
409
+ self,
410
+ x: torch.Tensor,
411
+ x_lengths: torch.Tensor,
412
+ tone: torch.Tensor,
413
+ language: torch.Tensor,
414
+ bert: torch.Tensor,
415
+ style_vec: torch.Tensor,
416
+ g: Optional[torch.Tensor] = None,
417
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
418
+ bert_emb = self.bert_proj(bert).transpose(1, 2)
419
+ style_emb = self.style_proj(style_vec.unsqueeze(1))
420
+ x = (
421
+ self.emb(x)
422
+ + self.tone_emb(tone)
423
+ + self.language_emb(language)
424
+ + bert_emb
425
+ + style_emb
426
+ ) * math.sqrt(
427
+ self.hidden_channels
428
+ ) # [b, t, h]
429
+ x = torch.transpose(x, 1, -1) # [b, h, t]
430
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
431
+ x.dtype
432
+ )
433
+
434
+ x = self.encoder(x * x_mask, x_mask, g=g)
435
+ stats = self.proj(x) * x_mask
436
+
437
+ m, logs = torch.split(stats, self.out_channels, dim=1)
438
+ return x, m, logs, x_mask
439
+
440
+
441
+ class ResidualCouplingBlock(nn.Module):
442
+ def __init__(
443
+ self,
444
+ channels: int,
445
+ hidden_channels: int,
446
+ kernel_size: int,
447
+ dilation_rate: int,
448
+ n_layers: int,
449
+ n_flows: int = 4,
450
+ gin_channels: int = 0,
451
+ ) -> None:
452
+ super().__init__()
453
+ self.channels = channels
454
+ self.hidden_channels = hidden_channels
455
+ self.kernel_size = kernel_size
456
+ self.dilation_rate = dilation_rate
457
+ self.n_layers = n_layers
458
+ self.n_flows = n_flows
459
+ self.gin_channels = gin_channels
460
+
461
+ self.flows = nn.ModuleList()
462
+ for i in range(n_flows):
463
+ self.flows.append(
464
+ modules.ResidualCouplingLayer(
465
+ channels,
466
+ hidden_channels,
467
+ kernel_size,
468
+ dilation_rate,
469
+ n_layers,
470
+ gin_channels=gin_channels,
471
+ mean_only=True,
472
+ )
473
+ )
474
+ self.flows.append(modules.Flip())
475
+
476
+ def forward(
477
+ self,
478
+ x: torch.Tensor,
479
+ x_mask: torch.Tensor,
480
+ g: Optional[torch.Tensor] = None,
481
+ reverse: bool = False,
482
+ ) -> torch.Tensor:
483
+ if not reverse:
484
+ for flow in self.flows:
485
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
486
+ else:
487
+ for flow in reversed(self.flows):
488
+ x = flow(x, x_mask, g=g, reverse=reverse)
489
+ return x
490
+
491
+
492
+ class PosteriorEncoder(nn.Module):
493
+ def __init__(
494
+ self,
495
+ in_channels: int,
496
+ out_channels: int,
497
+ hidden_channels: int,
498
+ kernel_size: int,
499
+ dilation_rate: int,
500
+ n_layers: int,
501
+ gin_channels: int = 0,
502
+ ) -> None:
503
+ super().__init__()
504
+ self.in_channels = in_channels
505
+ self.out_channels = out_channels
506
+ self.hidden_channels = hidden_channels
507
+ self.kernel_size = kernel_size
508
+ self.dilation_rate = dilation_rate
509
+ self.n_layers = n_layers
510
+ self.gin_channels = gin_channels
511
+
512
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
513
+ self.enc = modules.WN(
514
+ hidden_channels,
515
+ kernel_size,
516
+ dilation_rate,
517
+ n_layers,
518
+ gin_channels=gin_channels,
519
+ )
520
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
521
+
522
+ def forward(
523
+ self,
524
+ x: torch.Tensor,
525
+ x_lengths: torch.Tensor,
526
+ g: Optional[torch.Tensor] = None,
527
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
528
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
529
+ x.dtype
530
+ )
531
+ x = self.pre(x) * x_mask
532
+ x = self.enc(x, x_mask, g=g)
533
+ stats = self.proj(x) * x_mask
534
+ m, logs = torch.split(stats, self.out_channels, dim=1)
535
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
536
+ return z, m, logs, x_mask
537
+
538
+
539
+ class Generator(torch.nn.Module):
540
+ def __init__(
541
+ self,
542
+ initial_channel: int,
543
+ resblock_str: str,
544
+ resblock_kernel_sizes: list[int],
545
+ resblock_dilation_sizes: list[list[int]],
546
+ upsample_rates: list[int],
547
+ upsample_initial_channel: int,
548
+ upsample_kernel_sizes: list[int],
549
+ gin_channels: int = 0,
550
+ ) -> None:
551
+ super(Generator, self).__init__()
552
+ self.num_kernels = len(resblock_kernel_sizes)
553
+ self.num_upsamples = len(upsample_rates)
554
+ self.conv_pre = Conv1d(
555
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
556
+ )
557
+ resblock = modules.ResBlock1 if resblock_str == "1" else modules.ResBlock2
558
+
559
+ self.ups = nn.ModuleList()
560
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
561
+ self.ups.append(
562
+ weight_norm(
563
+ ConvTranspose1d(
564
+ upsample_initial_channel // (2**i),
565
+ upsample_initial_channel // (2 ** (i + 1)),
566
+ k,
567
+ u,
568
+ padding=(k - u) // 2,
569
+ )
570
+ )
571
+ )
572
+
573
+ self.resblocks = nn.ModuleList()
574
+ ch = None
575
+ for i in range(len(self.ups)):
576
+ ch = upsample_initial_channel // (2 ** (i + 1))
577
+ for j, (k, d) in enumerate(
578
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
579
+ ):
580
+ self.resblocks.append(resblock(ch, k, d)) # type: ignore
581
+
582
+ assert ch is not None
583
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
584
+ self.ups.apply(commons.init_weights)
585
+
586
+ if gin_channels != 0:
587
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
588
+
589
+ def forward(
590
+ self, x: torch.Tensor, g: Optional[torch.Tensor] = None
591
+ ) -> torch.Tensor:
592
+ x = self.conv_pre(x)
593
+ if g is not None:
594
+ x = x + self.cond(g)
595
+
596
+ for i in range(self.num_upsamples):
597
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
598
+ x = self.ups[i](x)
599
+ xs = None
600
+ for j in range(self.num_kernels):
601
+ if xs is None:
602
+ xs = self.resblocks[i * self.num_kernels + j](x)
603
+ else:
604
+ xs += self.resblocks[i * self.num_kernels + j](x)
605
+ assert xs is not None
606
+ x = xs / self.num_kernels
607
+ x = F.leaky_relu(x)
608
+ x = self.conv_post(x)
609
+ x = torch.tanh(x)
610
+
611
+ return x
612
+
613
+ def remove_weight_norm(self) -> None:
614
+ print("Removing weight norm...")
615
+ for layer in self.ups:
616
+ remove_weight_norm(layer)
617
+ for layer in self.resblocks:
618
+ layer.remove_weight_norm()
619
+
620
+
621
+ class DiscriminatorP(torch.nn.Module):
622
+ def __init__(
623
+ self,
624
+ period: int,
625
+ kernel_size: int = 5,
626
+ stride: int = 3,
627
+ use_spectral_norm: bool = False,
628
+ ) -> None:
629
+ super(DiscriminatorP, self).__init__()
630
+ self.period = period
631
+ self.use_spectral_norm = use_spectral_norm
632
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
633
+ self.convs = nn.ModuleList(
634
+ [
635
+ norm_f(
636
+ Conv2d(
637
+ 1,
638
+ 32,
639
+ (kernel_size, 1),
640
+ (stride, 1),
641
+ padding=(commons.get_padding(kernel_size, 1), 0),
642
+ )
643
+ ),
644
+ norm_f(
645
+ Conv2d(
646
+ 32,
647
+ 128,
648
+ (kernel_size, 1),
649
+ (stride, 1),
650
+ padding=(commons.get_padding(kernel_size, 1), 0),
651
+ )
652
+ ),
653
+ norm_f(
654
+ Conv2d(
655
+ 128,
656
+ 512,
657
+ (kernel_size, 1),
658
+ (stride, 1),
659
+ padding=(commons.get_padding(kernel_size, 1), 0),
660
+ )
661
+ ),
662
+ norm_f(
663
+ Conv2d(
664
+ 512,
665
+ 1024,
666
+ (kernel_size, 1),
667
+ (stride, 1),
668
+ padding=(commons.get_padding(kernel_size, 1), 0),
669
+ )
670
+ ),
671
+ norm_f(
672
+ Conv2d(
673
+ 1024,
674
+ 1024,
675
+ (kernel_size, 1),
676
+ 1,
677
+ padding=(commons.get_padding(kernel_size, 1), 0),
678
+ )
679
+ ),
680
+ ]
681
+ )
682
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
683
+
684
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
685
+ fmap = []
686
+
687
+ # 1d to 2d
688
+ b, c, t = x.shape
689
+ if t % self.period != 0: # pad first
690
+ n_pad = self.period - (t % self.period)
691
+ x = F.pad(x, (0, n_pad), "reflect")
692
+ t = t + n_pad
693
+ x = x.view(b, c, t // self.period, self.period)
694
+
695
+ for layer in self.convs:
696
+ x = layer(x)
697
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
698
+ fmap.append(x)
699
+ x = self.conv_post(x)
700
+ fmap.append(x)
701
+ x = torch.flatten(x, 1, -1)
702
+
703
+ return x, fmap
704
+
705
+
706
+ class DiscriminatorS(torch.nn.Module):
707
+ def __init__(self, use_spectral_norm: bool = False) -> None:
708
+ super(DiscriminatorS, self).__init__()
709
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
710
+ self.convs = nn.ModuleList(
711
+ [
712
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
713
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
714
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
715
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
716
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
717
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
718
+ ]
719
+ )
720
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
721
+
722
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, list[torch.Tensor]]:
723
+ fmap = []
724
+
725
+ for layer in self.convs:
726
+ x = layer(x)
727
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
728
+ fmap.append(x)
729
+ x = self.conv_post(x)
730
+ fmap.append(x)
731
+ x = torch.flatten(x, 1, -1)
732
+
733
+ return x, fmap
734
+
735
+
736
+ class MultiPeriodDiscriminator(torch.nn.Module):
737
+ def __init__(self, use_spectral_norm: bool = False) -> None:
738
+ super(MultiPeriodDiscriminator, self).__init__()
739
+ periods = [2, 3, 5, 7, 11]
740
+
741
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
742
+ discs = discs + [
743
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
744
+ ]
745
+ self.discriminators = nn.ModuleList(discs)
746
+
747
+ def forward(
748
+ self,
749
+ y: torch.Tensor,
750
+ y_hat: torch.Tensor,
751
+ ) -> tuple[
752
+ list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]
753
+ ]:
754
+ y_d_rs = []
755
+ y_d_gs = []
756
+ fmap_rs = []
757
+ fmap_gs = []
758
+ for i, d in enumerate(self.discriminators):
759
+ y_d_r, fmap_r = d(y)
760
+ y_d_g, fmap_g = d(y_hat)
761
+ y_d_rs.append(y_d_r)
762
+ y_d_gs.append(y_d_g)
763
+ fmap_rs.append(fmap_r)
764
+ fmap_gs.append(fmap_g)
765
+
766
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
767
+
768
+
769
+ class WavLMDiscriminator(nn.Module):
770
+ """docstring for Discriminator."""
771
+
772
+ def __init__(
773
+ self,
774
+ slm_hidden: int = 768,
775
+ slm_layers: int = 13,
776
+ initial_channel: int = 64,
777
+ use_spectral_norm: bool = False,
778
+ ) -> None:
779
+ super(WavLMDiscriminator, self).__init__()
780
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
781
+ self.pre = norm_f(
782
+ Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
783
+ )
784
+
785
+ self.convs = nn.ModuleList(
786
+ [
787
+ norm_f(
788
+ nn.Conv1d(
789
+ initial_channel, initial_channel * 2, kernel_size=5, padding=2
790
+ )
791
+ ),
792
+ norm_f(
793
+ nn.Conv1d(
794
+ initial_channel * 2,
795
+ initial_channel * 4,
796
+ kernel_size=5,
797
+ padding=2,
798
+ )
799
+ ),
800
+ norm_f(
801
+ nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
802
+ ),
803
+ ]
804
+ )
805
+
806
+ self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
807
+
808
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
809
+ x = self.pre(x)
810
+
811
+ fmap = []
812
+ for l in self.convs:
813
+ x = l(x)
814
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
815
+ fmap.append(x)
816
+ x = self.conv_post(x)
817
+ x = torch.flatten(x, 1, -1)
818
+
819
+ return x
820
+
821
+
822
+ class ReferenceEncoder(nn.Module):
823
+ """
824
+ inputs --- [N, Ty/r, n_mels*r] mels
825
+ outputs --- [N, ref_enc_gru_size]
826
+ """
827
+
828
+ def __init__(self, spec_channels: int, gin_channels: int = 0) -> None:
829
+ super().__init__()
830
+ self.spec_channels = spec_channels
831
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
832
+ K = len(ref_enc_filters)
833
+ filters = [1] + ref_enc_filters
834
+ convs = [
835
+ weight_norm(
836
+ nn.Conv2d(
837
+ in_channels=filters[i],
838
+ out_channels=filters[i + 1],
839
+ kernel_size=(3, 3),
840
+ stride=(2, 2),
841
+ padding=(1, 1),
842
+ )
843
+ )
844
+ for i in range(K)
845
+ ]
846
+ self.convs = nn.ModuleList(convs)
847
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])
848
+
849
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
850
+ self.gru = nn.GRU(
851
+ input_size=ref_enc_filters[-1] * out_channels,
852
+ hidden_size=256 // 2,
853
+ batch_first=True,
854
+ )
855
+ self.proj = nn.Linear(128, gin_channels)
856
+
857
+ def forward(
858
+ self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None
859
+ ) -> torch.Tensor:
860
+ N = inputs.size(0)
861
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
862
+ for conv in self.convs:
863
+ out = conv(out)
864
+ # out = wn(out)
865
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
866
+
867
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
868
+ T = out.size(1)
869
+ N = out.size(0)
870
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
871
+
872
+ self.gru.flatten_parameters()
873
+ memory, out = self.gru(out) # out --- [1, N, 128]
874
+
875
+ return self.proj(out.squeeze(0))
876
+
877
+ def calculate_channels(
878
+ self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int
879
+ ) -> int:
880
+ for i in range(n_convs):
881
+ L = (L - kernel_size + 2 * pad) // stride + 1
882
+ return L
883
+
884
+
885
+ class SynthesizerTrn(nn.Module):
886
+ """
887
+ Synthesizer for Training
888
+ """
889
+
890
+ def __init__(
891
+ self,
892
+ n_vocab: int,
893
+ spec_channels: int,
894
+ segment_size: int,
895
+ inter_channels: int,
896
+ hidden_channels: int,
897
+ filter_channels: int,
898
+ n_heads: int,
899
+ n_layers: int,
900
+ kernel_size: int,
901
+ p_dropout: float,
902
+ resblock: str,
903
+ resblock_kernel_sizes: list[int],
904
+ resblock_dilation_sizes: list[list[int]],
905
+ upsample_rates: list[int],
906
+ upsample_initial_channel: int,
907
+ upsample_kernel_sizes: list[int],
908
+ n_speakers: int = 256,
909
+ gin_channels: int = 256,
910
+ use_sdp: bool = True,
911
+ n_flow_layer: int = 4,
912
+ n_layers_trans_flow: int = 6,
913
+ flow_share_parameter: bool = False,
914
+ use_transformer_flow: bool = True,
915
+ **kwargs: Any,
916
+ ) -> None:
917
+ super().__init__()
918
+ self.n_vocab = n_vocab
919
+ self.spec_channels = spec_channels
920
+ self.inter_channels = inter_channels
921
+ self.hidden_channels = hidden_channels
922
+ self.filter_channels = filter_channels
923
+ self.n_heads = n_heads
924
+ self.n_layers = n_layers
925
+ self.kernel_size = kernel_size
926
+ self.p_dropout = p_dropout
927
+ self.resblock = resblock
928
+ self.resblock_kernel_sizes = resblock_kernel_sizes
929
+ self.resblock_dilation_sizes = resblock_dilation_sizes
930
+ self.upsample_rates = upsample_rates
931
+ self.upsample_initial_channel = upsample_initial_channel
932
+ self.upsample_kernel_sizes = upsample_kernel_sizes
933
+ self.segment_size = segment_size
934
+ self.n_speakers = n_speakers
935
+ self.gin_channels = gin_channels
936
+ self.n_layers_trans_flow = n_layers_trans_flow
937
+ self.use_spk_conditioned_encoder = kwargs.get(
938
+ "use_spk_conditioned_encoder", True
939
+ )
940
+ self.use_sdp = use_sdp
941
+ self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
942
+ self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
943
+ self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
944
+ self.current_mas_noise_scale = self.mas_noise_scale_initial
945
+ if self.use_spk_conditioned_encoder and gin_channels > 0:
946
+ self.enc_gin_channels = gin_channels
947
+ self.enc_p = TextEncoder(
948
+ n_vocab,
949
+ inter_channels,
950
+ hidden_channels,
951
+ filter_channels,
952
+ n_heads,
953
+ n_layers,
954
+ kernel_size,
955
+ p_dropout,
956
+ gin_channels=self.enc_gin_channels,
957
+ )
958
+ self.dec = Generator(
959
+ inter_channels,
960
+ resblock,
961
+ resblock_kernel_sizes,
962
+ resblock_dilation_sizes,
963
+ upsample_rates,
964
+ upsample_initial_channel,
965
+ upsample_kernel_sizes,
966
+ gin_channels=gin_channels,
967
+ )
968
+ self.enc_q = PosteriorEncoder(
969
+ spec_channels,
970
+ inter_channels,
971
+ hidden_channels,
972
+ 5,
973
+ 1,
974
+ 16,
975
+ gin_channels=gin_channels,
976
+ )
977
+ if use_transformer_flow:
978
+ self.flow = TransformerCouplingBlock(
979
+ inter_channels,
980
+ hidden_channels,
981
+ filter_channels,
982
+ n_heads,
983
+ n_layers_trans_flow,
984
+ 5,
985
+ p_dropout,
986
+ n_flow_layer,
987
+ gin_channels=gin_channels,
988
+ share_parameter=flow_share_parameter,
989
+ )
990
+ else:
991
+ self.flow = ResidualCouplingBlock(
992
+ inter_channels,
993
+ hidden_channels,
994
+ 5,
995
+ 1,
996
+ n_flow_layer,
997
+ gin_channels=gin_channels,
998
+ )
999
+ self.sdp = StochasticDurationPredictor(
1000
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
1001
+ )
1002
+ self.dp = DurationPredictor(
1003
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
1004
+ )
1005
+
1006
+ if n_speakers >= 1:
1007
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
1008
+ else:
1009
+ self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
1010
+
1011
+ def forward(
1012
+ self,
1013
+ x: torch.Tensor,
1014
+ x_lengths: torch.Tensor,
1015
+ y: torch.Tensor,
1016
+ y_lengths: torch.Tensor,
1017
+ sid: torch.Tensor,
1018
+ tone: torch.Tensor,
1019
+ language: torch.Tensor,
1020
+ bert: torch.Tensor,
1021
+ style_vec: torch.Tensor,
1022
+ ) -> tuple[
1023
+ torch.Tensor,
1024
+ torch.Tensor,
1025
+ torch.Tensor,
1026
+ torch.Tensor,
1027
+ torch.Tensor,
1028
+ torch.Tensor,
1029
+ torch.Tensor,
1030
+ tuple[torch.Tensor, ...],
1031
+ tuple[torch.Tensor, ...],
1032
+ ]:
1033
+ if self.n_speakers > 0:
1034
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
1035
+ else:
1036
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
1037
+ x, m_p, logs_p, x_mask = self.enc_p(
1038
+ x, x_lengths, tone, language, bert, style_vec, g=g
1039
+ )
1040
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
1041
+ z_p = self.flow(z, y_mask, g=g)
1042
+
1043
+ with torch.no_grad():
1044
+ # negative cross-entropy
1045
+ s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
1046
+ neg_cent1 = torch.sum(
1047
+ -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
1048
+ ) # [b, 1, t_s]
1049
+ neg_cent2 = torch.matmul(
1050
+ -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
1051
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
1052
+ neg_cent3 = torch.matmul(
1053
+ z_p.transpose(1, 2), (m_p * s_p_sq_r)
1054
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
1055
+ neg_cent4 = torch.sum(
1056
+ -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
1057
+ ) # [b, 1, t_s]
1058
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
1059
+ if self.use_noise_scaled_mas:
1060
+ epsilon = (
1061
+ torch.std(neg_cent)
1062
+ * torch.randn_like(neg_cent)
1063
+ * self.current_mas_noise_scale
1064
+ )
1065
+ neg_cent = neg_cent + epsilon
1066
+
1067
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
1068
+ attn = (
1069
+ monotonic_alignment.maximum_path(neg_cent, attn_mask.squeeze(1))
1070
+ .unsqueeze(1)
1071
+ .detach()
1072
+ )
1073
+
1074
+ w = attn.sum(2)
1075
+
1076
+ l_length_sdp = self.sdp(x, x_mask, w, g=g)
1077
+ l_length_sdp = l_length_sdp / torch.sum(x_mask)
1078
+
1079
+ logw_ = torch.log(w + 1e-6) * x_mask
1080
+ logw = self.dp(x, x_mask, g=g)
1081
+ # logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
1082
+ l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
1083
+ x_mask
1084
+ ) # for averaging
1085
+ # l_length_sdp += torch.sum((logw_sdp - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
1086
+
1087
+ l_length = l_length_dp + l_length_sdp
1088
+
1089
+ # expand prior
1090
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
1091
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
1092
+
1093
+ z_slice, ids_slice = commons.rand_slice_segments(
1094
+ z, y_lengths, self.segment_size
1095
+ )
1096
+ o = self.dec(z_slice, g=g)
1097
+ return (
1098
+ o,
1099
+ l_length,
1100
+ attn,
1101
+ ids_slice,
1102
+ x_mask,
1103
+ y_mask,
1104
+ (z, z_p, m_p, logs_p, m_q, logs_q), # type: ignore
1105
+ (x, logw, logw_), # , logw_sdp),
1106
+ g,
1107
+ )
1108
+
1109
+ def infer(
1110
+ self,
1111
+ x: torch.Tensor,
1112
+ x_lengths: torch.Tensor,
1113
+ sid: torch.Tensor,
1114
+ tone: torch.Tensor,
1115
+ language: torch.Tensor,
1116
+ bert: torch.Tensor,
1117
+ style_vec: torch.Tensor,
1118
+ noise_scale: float = 0.667,
1119
+ length_scale: float = 1.0,
1120
+ noise_scale_w: float = 0.8,
1121
+ max_len: Optional[int] = None,
1122
+ sdp_ratio: float = 0.0,
1123
+ y: Optional[torch.Tensor] = None,
1124
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, ...]]:
1125
+ # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
1126
+ # g = self.gst(y)
1127
+ if self.n_speakers > 0:
1128
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
1129
+ else:
1130
+ assert y is not None
1131
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
1132
+ x, m_p, logs_p, x_mask = self.enc_p(
1133
+ x, x_lengths, tone, language, bert, style_vec, g=g
1134
+ )
1135
+ logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
1136
+ sdp_ratio
1137
+ ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
1138
+ w = torch.exp(logw) * x_mask * length_scale
1139
+ w_ceil = torch.ceil(w)
1140
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
1141
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
1142
+ x_mask.dtype
1143
+ )
1144
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
1145
+ attn = commons.generate_path(w_ceil, attn_mask)
1146
+
1147
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
1148
+ 1, 2
1149
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1150
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
1151
+ 1, 2
1152
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1153
+
1154
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1155
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
1156
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
1157
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
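Compared with the multilingual model earlier in the diff, the JP-Extra `TextEncoder` consumes a single 1024-dim BERT feature sequence plus a 256-dim style vector, and its forward pass simply sums the symbol, tone, and language embeddings with the projected BERT and style features before scaling by sqrt(hidden_channels). The sketch below reproduces that summation in isolation; every size other than the 1024/256 projection widths taken from the class definition is an arbitrary placeholder.

import math

import torch
from torch import nn

hidden, t = 192, 7                                  # illustrative hidden size and phoneme count
emb = nn.Embedding(100, hidden)                     # stands in for the SYMBOLS embedding
tone_emb = nn.Embedding(12, hidden)                 # stands in for NUM_TONES
lang_emb = nn.Embedding(3, hidden)                  # stands in for NUM_LANGUAGES
bert_proj = nn.Conv1d(1024, hidden, 1)              # same projection width as TextEncoder
style_proj = nn.Linear(256, hidden)                 # same style-vector width as TextEncoder

x = torch.randint(0, 100, (1, t))                   # phoneme ids
tone = torch.randint(0, 12, (1, t))
language = torch.zeros(1, t, dtype=torch.long)
bert = torch.randn(1, 1024, t)                      # BERT features, [b, 1024, t]
style_vec = torch.randn(1, 256)                     # one style vector per utterance

h = (
    emb(x)
    + tone_emb(tone)
    + lang_emb(language)
    + bert_proj(bert).transpose(1, 2)               # [1, t, hidden]
    + style_proj(style_vec.unsqueeze(1))            # [1, 1, hidden], broadcast over t
) * math.sqrt(hidden)
print(h.shape)  # torch.Size([1, 7, 192])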
style_bert_vits2/models/modules.py ADDED
@@ -0,0 +1,642 @@
1
+ import math
2
+ from typing import Any, Optional, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import Conv1d
7
+ from torch.nn import functional as F
8
+ from torch.nn.utils import remove_weight_norm, weight_norm
9
+
10
+ from style_bert_vits2.models import commons
11
+ from style_bert_vits2.models.attentions import Encoder
12
+ from style_bert_vits2.models.transforms import piecewise_rational_quadratic_transform
13
+
14
+
15
+ LRELU_SLOPE = 0.1
16
+
17
+
18
+ class LayerNorm(nn.Module):
19
+ def __init__(self, channels: int, eps: float = 1e-5) -> None:
20
+ super().__init__()
21
+ self.channels = channels
22
+ self.eps = eps
23
+
24
+ self.gamma = nn.Parameter(torch.ones(channels))
25
+ self.beta = nn.Parameter(torch.zeros(channels))
26
+
27
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
28
+ x = x.transpose(1, -1)
29
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
30
+ return x.transpose(1, -1)
31
+
32
+
33
+ class ConvReluNorm(nn.Module):
34
+ def __init__(
35
+ self,
36
+ in_channels: int,
37
+ hidden_channels: int,
38
+ out_channels: int,
39
+ kernel_size: int,
40
+ n_layers: int,
41
+ p_dropout: float,
42
+ ) -> None:
43
+ super().__init__()
44
+ self.in_channels = in_channels
45
+ self.hidden_channels = hidden_channels
46
+ self.out_channels = out_channels
47
+ self.kernel_size = kernel_size
48
+ self.n_layers = n_layers
49
+ self.p_dropout = p_dropout
50
+ assert n_layers > 1, "Number of layers should be larger than 1."
51
+
52
+ self.conv_layers = nn.ModuleList()
53
+ self.norm_layers = nn.ModuleList()
54
+ self.conv_layers.append(
55
+ nn.Conv1d(
56
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
57
+ )
58
+ )
59
+ self.norm_layers.append(LayerNorm(hidden_channels))
60
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
61
+ for _ in range(n_layers - 1):
62
+ self.conv_layers.append(
63
+ nn.Conv1d(
64
+ hidden_channels,
65
+ hidden_channels,
66
+ kernel_size,
67
+ padding=kernel_size // 2,
68
+ )
69
+ )
70
+ self.norm_layers.append(LayerNorm(hidden_channels))
71
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
72
+ self.proj.weight.data.zero_()
73
+ assert self.proj.bias is not None
74
+ self.proj.bias.data.zero_()
75
+
76
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
77
+ x_org = x
78
+ for i in range(self.n_layers):
79
+ x = self.conv_layers[i](x * x_mask)
80
+ x = self.norm_layers[i](x)
81
+ x = self.relu_drop(x)
82
+ x = x_org + self.proj(x)
83
+ return x * x_mask
84
+
85
+
86
+ class DDSConv(nn.Module):
87
+ """
88
+ Dilated and Depth-Separable Convolution
89
+ """
90
+
91
+ def __init__(
92
+ self, channels: int, kernel_size: int, n_layers: int, p_dropout: float = 0.0
93
+ ) -> None:
94
+ super().__init__()
95
+ self.channels = channels
96
+ self.kernel_size = kernel_size
97
+ self.n_layers = n_layers
98
+ self.p_dropout = p_dropout
99
+
100
+ self.drop = nn.Dropout(p_dropout)
101
+ self.convs_sep = nn.ModuleList()
102
+ self.convs_1x1 = nn.ModuleList()
103
+ self.norms_1 = nn.ModuleList()
104
+ self.norms_2 = nn.ModuleList()
105
+ for i in range(n_layers):
106
+ dilation = kernel_size**i
107
+ padding = (kernel_size * dilation - dilation) // 2
108
+ self.convs_sep.append(
109
+ nn.Conv1d(
110
+ channels,
111
+ channels,
112
+ kernel_size,
113
+ groups=channels,
114
+ dilation=dilation,
115
+ padding=padding,
116
+ )
117
+ )
118
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
119
+ self.norms_1.append(LayerNorm(channels))
120
+ self.norms_2.append(LayerNorm(channels))
121
+
122
+ def forward(
123
+ self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None
124
+ ) -> torch.Tensor:
125
+ if g is not None:
126
+ x = x + g
127
+ for i in range(self.n_layers):
128
+ y = self.convs_sep[i](x * x_mask)
129
+ y = self.norms_1[i](y)
130
+ y = F.gelu(y)
131
+ y = self.convs_1x1[i](y)
132
+ y = self.norms_2[i](y)
133
+ y = F.gelu(y)
134
+ y = self.drop(y)
135
+ x = x + y
136
+ return x * x_mask
137
+
138
+
139
+ class WN(torch.nn.Module):
140
+ def __init__(
141
+ self,
142
+ hidden_channels: int,
143
+ kernel_size: int,
144
+ dilation_rate: int,
145
+ n_layers: int,
146
+ gin_channels: int = 0,
147
+ p_dropout: float = 0,
148
+ ) -> None:
149
+ super(WN, self).__init__()
150
+ assert kernel_size % 2 == 1
151
+ self.hidden_channels = hidden_channels
152
+ self.kernel_size = (kernel_size,)
153
+ self.dilation_rate = dilation_rate
154
+ self.n_layers = n_layers
155
+ self.gin_channels = gin_channels
156
+ self.p_dropout = p_dropout
157
+
158
+ self.in_layers = torch.nn.ModuleList()
159
+ self.res_skip_layers = torch.nn.ModuleList()
160
+ self.drop = nn.Dropout(p_dropout)
161
+
162
+ if gin_channels != 0:
163
+ cond_layer = torch.nn.Conv1d(
164
+ gin_channels, 2 * hidden_channels * n_layers, 1
165
+ )
166
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
167
+
168
+ for i in range(n_layers):
169
+ dilation = dilation_rate**i
170
+ padding = int((kernel_size * dilation - dilation) / 2)
171
+ in_layer = torch.nn.Conv1d(
172
+ hidden_channels,
173
+ 2 * hidden_channels,
174
+ kernel_size,
175
+ dilation=dilation,
176
+ padding=padding,
177
+ )
178
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
179
+ self.in_layers.append(in_layer)
180
+
181
+ # last one is not necessary
182
+ if i < n_layers - 1:
183
+ res_skip_channels = 2 * hidden_channels
184
+ else:
185
+ res_skip_channels = hidden_channels
186
+
187
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
188
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
189
+ self.res_skip_layers.append(res_skip_layer)
190
+
191
+ def forward(
192
+ self,
193
+ x: torch.Tensor,
194
+ x_mask: torch.Tensor,
195
+ g: Optional[torch.Tensor] = None,
196
+ **kwargs: Any,
197
+ ) -> torch.Tensor:
198
+ output = torch.zeros_like(x)
199
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
200
+
201
+ if g is not None:
202
+ g = self.cond_layer(g)
203
+
204
+ for i in range(self.n_layers):
205
+ x_in = self.in_layers[i](x)
206
+ if g is not None:
207
+ cond_offset = i * 2 * self.hidden_channels
208
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
209
+ else:
210
+ g_l = torch.zeros_like(x_in)
211
+
212
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
213
+ acts = self.drop(acts)
214
+
215
+ res_skip_acts = self.res_skip_layers[i](acts)
216
+ if i < self.n_layers - 1:
217
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
218
+ x = (x + res_acts) * x_mask
219
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
220
+ else:
221
+ output = output + res_skip_acts
222
+ return output * x_mask
223
+
224
+ def remove_weight_norm(self) -> None:
225
+ if self.gin_channels != 0:
226
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
227
+ for l in self.in_layers:
228
+ torch.nn.utils.remove_weight_norm(l)
229
+ for l in self.res_skip_layers:
230
+ torch.nn.utils.remove_weight_norm(l)
231
+
232
+
233
+ class ResBlock1(torch.nn.Module):
234
+ def __init__(
235
+ self,
236
+ channels: int,
237
+ kernel_size: int = 3,
238
+ dilation: tuple[int, int, int] = (1, 3, 5),
239
+ ) -> None:
240
+ super(ResBlock1, self).__init__()
241
+ self.convs1 = nn.ModuleList(
242
+ [
243
+ weight_norm(
244
+ Conv1d(
245
+ channels,
246
+ channels,
247
+ kernel_size,
248
+ 1,
249
+ dilation=dilation[0],
250
+ padding=commons.get_padding(kernel_size, dilation[0]),
251
+ )
252
+ ),
253
+ weight_norm(
254
+ Conv1d(
255
+ channels,
256
+ channels,
257
+ kernel_size,
258
+ 1,
259
+ dilation=dilation[1],
260
+ padding=commons.get_padding(kernel_size, dilation[1]),
261
+ )
262
+ ),
263
+ weight_norm(
264
+ Conv1d(
265
+ channels,
266
+ channels,
267
+ kernel_size,
268
+ 1,
269
+ dilation=dilation[2],
270
+ padding=commons.get_padding(kernel_size, dilation[2]),
271
+ )
272
+ ),
273
+ ]
274
+ )
275
+ self.convs1.apply(commons.init_weights)
276
+
277
+ self.convs2 = nn.ModuleList(
278
+ [
279
+ weight_norm(
280
+ Conv1d(
281
+ channels,
282
+ channels,
283
+ kernel_size,
284
+ 1,
285
+ dilation=1,
286
+ padding=commons.get_padding(kernel_size, 1),
287
+ )
288
+ ),
289
+ weight_norm(
290
+ Conv1d(
291
+ channels,
292
+ channels,
293
+ kernel_size,
294
+ 1,
295
+ dilation=1,
296
+ padding=commons.get_padding(kernel_size, 1),
297
+ )
298
+ ),
299
+ weight_norm(
300
+ Conv1d(
301
+ channels,
302
+ channels,
303
+ kernel_size,
304
+ 1,
305
+ dilation=1,
306
+ padding=commons.get_padding(kernel_size, 1),
307
+ )
308
+ ),
309
+ ]
310
+ )
311
+ self.convs2.apply(commons.init_weights)
312
+
313
+ def forward(
314
+ self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None
315
+ ) -> torch.Tensor:
316
+ for c1, c2 in zip(self.convs1, self.convs2):
317
+ xt = F.leaky_relu(x, LRELU_SLOPE)
318
+ if x_mask is not None:
319
+ xt = xt * x_mask
320
+ xt = c1(xt)
321
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
322
+ if x_mask is not None:
323
+ xt = xt * x_mask
324
+ xt = c2(xt)
325
+ x = xt + x
326
+ if x_mask is not None:
327
+ x = x * x_mask
328
+ return x
329
+
330
+ def remove_weight_norm(self) -> None:
331
+ for l in self.convs1:
332
+ remove_weight_norm(l)
333
+ for l in self.convs2:
334
+ remove_weight_norm(l)
335
+
336
+
337
+ class ResBlock2(torch.nn.Module):
338
+ def __init__(
339
+ self, channels: int, kernel_size: int = 3, dilation: tuple[int, int] = (1, 3)
340
+ ) -> None:
341
+ super(ResBlock2, self).__init__()
342
+ self.convs = nn.ModuleList(
343
+ [
344
+ weight_norm(
345
+ Conv1d(
346
+ channels,
347
+ channels,
348
+ kernel_size,
349
+ 1,
350
+ dilation=dilation[0],
351
+ padding=commons.get_padding(kernel_size, dilation[0]),
352
+ )
353
+ ),
354
+ weight_norm(
355
+ Conv1d(
356
+ channels,
357
+ channels,
358
+ kernel_size,
359
+ 1,
360
+ dilation=dilation[1],
361
+ padding=commons.get_padding(kernel_size, dilation[1]),
362
+ )
363
+ ),
364
+ ]
365
+ )
366
+ self.convs.apply(commons.init_weights)
367
+
368
+ def forward(
369
+ self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None
370
+ ) -> torch.Tensor:
371
+ for c in self.convs:
372
+ xt = F.leaky_relu(x, LRELU_SLOPE)
373
+ if x_mask is not None:
374
+ xt = xt * x_mask
375
+ xt = c(xt)
376
+ x = xt + x
377
+ if x_mask is not None:
378
+ x = x * x_mask
379
+ return x
380
+
381
+ def remove_weight_norm(self) -> None:
382
+ for l in self.convs:
383
+ remove_weight_norm(l)
384
+
385
+
386
+ class Log(nn.Module):
387
+ def forward(
388
+ self,
389
+ x: torch.Tensor,
390
+ x_mask: torch.Tensor,
391
+ reverse: bool = False,
392
+ **kwargs: Any,
393
+ ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
394
+ if not reverse:
395
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
396
+ logdet = torch.sum(-y, [1, 2])
397
+ return y, logdet
398
+ else:
399
+ x = torch.exp(x) * x_mask
400
+ return x
401
+
402
+
403
+ class Flip(nn.Module):
404
+ def forward(
405
+ self,
406
+ x: torch.Tensor,
407
+ *args: Any,
408
+ reverse: bool = False,
409
+ **kwargs: Any,
410
+ ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
411
+ x = torch.flip(x, [1])
412
+ if not reverse:
413
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
414
+ return x, logdet
415
+ else:
416
+ return x
417
+
418
+
419
+ class ElementwiseAffine(nn.Module):
420
+ def __init__(self, channels: int) -> None:
421
+ super().__init__()
422
+ self.channels = channels
423
+ self.m = nn.Parameter(torch.zeros(channels, 1))
424
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
425
+
426
+ def forward(
427
+ self,
428
+ x: torch.Tensor,
429
+ x_mask: torch.Tensor,
430
+ reverse: bool = False,
431
+ **kwargs: Any,
432
+ ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
433
+ if not reverse:
434
+ y = self.m + torch.exp(self.logs) * x
435
+ y = y * x_mask
436
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
437
+ return y, logdet
438
+ else:
439
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
440
+ return x
441
+
442
+
443
+ class ResidualCouplingLayer(nn.Module):
444
+ def __init__(
445
+ self,
446
+ channels: int,
447
+ hidden_channels: int,
448
+ kernel_size: int,
449
+ dilation_rate: int,
450
+ n_layers: int,
451
+ p_dropout: float = 0,
452
+ gin_channels: int = 0,
453
+ mean_only: bool = False,
454
+ ) -> None:
455
+ assert channels % 2 == 0, "channels should be divisible by 2"
456
+ super().__init__()
457
+ self.channels = channels
458
+ self.hidden_channels = hidden_channels
459
+ self.kernel_size = kernel_size
460
+ self.dilation_rate = dilation_rate
461
+ self.n_layers = n_layers
462
+ self.half_channels = channels // 2
463
+ self.mean_only = mean_only
464
+
465
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
466
+ self.enc = WN(
467
+ hidden_channels,
468
+ kernel_size,
469
+ dilation_rate,
470
+ n_layers,
471
+ p_dropout=p_dropout,
472
+ gin_channels=gin_channels,
473
+ )
474
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
475
+ self.post.weight.data.zero_()
476
+ assert self.post.bias is not None
477
+ self.post.bias.data.zero_()
478
+
479
+ def forward(
480
+ self,
481
+ x: torch.Tensor,
482
+ x_mask: torch.Tensor,
483
+ g: Optional[torch.Tensor] = None,
484
+ reverse: bool = False,
485
+ ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
486
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
487
+ h = self.pre(x0) * x_mask
488
+ h = self.enc(h, x_mask, g=g)
489
+ stats = self.post(h) * x_mask
490
+ if not self.mean_only:
491
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
492
+ else:
493
+ m = stats
494
+ logs = torch.zeros_like(m)
495
+
496
+ if not reverse:
497
+ x1 = m + x1 * torch.exp(logs) * x_mask
498
+ x = torch.cat([x0, x1], 1)
499
+ logdet = torch.sum(logs, [1, 2])
500
+ return x, logdet
501
+ else:
502
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
503
+ x = torch.cat([x0, x1], 1)
504
+ return x
505
+
506
+
507
+ class ConvFlow(nn.Module):
508
+ def __init__(
509
+ self,
510
+ in_channels: int,
511
+ filter_channels: int,
512
+ kernel_size: int,
513
+ n_layers: int,
514
+ num_bins: int = 10,
515
+ tail_bound: float = 5.0,
516
+ ) -> None:
517
+ super().__init__()
518
+ self.in_channels = in_channels
519
+ self.filter_channels = filter_channels
520
+ self.kernel_size = kernel_size
521
+ self.n_layers = n_layers
522
+ self.num_bins = num_bins
523
+ self.tail_bound = tail_bound
524
+ self.half_channels = in_channels // 2
525
+
526
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
527
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
528
+ self.proj = nn.Conv1d(
529
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
530
+ )
531
+ self.proj.weight.data.zero_()
532
+ assert self.proj.bias is not None
533
+ self.proj.bias.data.zero_()
534
+
535
+ def forward(
536
+ self,
537
+ x: torch.Tensor,
538
+ x_mask: torch.Tensor,
539
+ g: Optional[torch.Tensor] = None,
540
+ reverse: bool = False,
541
+ ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
542
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
543
+ h = self.pre(x0)
544
+ h = self.convs(h, x_mask, g=g)
545
+ h = self.proj(h) * x_mask
546
+
547
+ b, c, t = x0.shape
548
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
549
+
550
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
551
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
552
+ self.filter_channels
553
+ )
554
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
555
+
556
+ x1, logabsdet = piecewise_rational_quadratic_transform(
557
+ x1,
558
+ unnormalized_widths,
559
+ unnormalized_heights,
560
+ unnormalized_derivatives,
561
+ inverse=reverse,
562
+ tails="linear",
563
+ tail_bound=self.tail_bound,
564
+ )
565
+
566
+ x = torch.cat([x0, x1], 1) * x_mask
567
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
568
+ if not reverse:
569
+ return x, logdet
570
+ else:
571
+ return x
572
+
573
+
574
+ class TransformerCouplingLayer(nn.Module):
575
+ def __init__(
576
+ self,
577
+ channels: int,
578
+ hidden_channels: int,
579
+ kernel_size: int,
580
+ n_layers: int,
581
+ n_heads: int,
582
+ p_dropout: float = 0,
583
+ filter_channels: int = 0,
584
+ mean_only: bool = False,
585
+ wn_sharing_parameter: Optional[nn.Module] = None,
586
+ gin_channels: int = 0,
587
+ ) -> None:
588
+ assert channels % 2 == 0, "channels should be divisible by 2"
589
+ super().__init__()
590
+ self.channels = channels
591
+ self.hidden_channels = hidden_channels
592
+ self.kernel_size = kernel_size
593
+ self.n_layers = n_layers
594
+ self.half_channels = channels // 2
595
+ self.mean_only = mean_only
596
+
597
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
598
+ self.enc = (
599
+ Encoder(
600
+ hidden_channels,
601
+ filter_channels,
602
+ n_heads,
603
+ n_layers,
604
+ kernel_size,
605
+ p_dropout,
606
+ isflow=True,
607
+ gin_channels=gin_channels,
608
+ )
609
+ if wn_sharing_parameter is None
610
+ else wn_sharing_parameter
611
+ )
612
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
613
+ self.post.weight.data.zero_()
614
+ assert self.post.bias is not None
615
+ self.post.bias.data.zero_()
616
+
617
+ def forward(
618
+ self,
619
+ x: torch.Tensor,
620
+ x_mask: torch.Tensor,
621
+ g: Optional[torch.Tensor] = None,
622
+ reverse: bool = False,
623
+ ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
624
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
625
+ h = self.pre(x0) * x_mask
626
+ h = self.enc(h, x_mask, g=g)
627
+ stats = self.post(h) * x_mask
628
+ if not self.mean_only:
629
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
630
+ else:
631
+ m = stats
632
+ logs = torch.zeros_like(m)
633
+
634
+ if not reverse:
635
+ x1 = m + x1 * torch.exp(logs) * x_mask
636
+ x = torch.cat([x0, x1], 1)
637
+ logdet = torch.sum(logs, [1, 2])
638
+ return x, logdet
639
+ else:
640
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
641
+ x = torch.cat([x0, x1], 1)
642
+ return x
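Both coupling layers defined above (`ResidualCouplingLayer` and `TransformerCouplingLayer`) finish with the same affine update: the untouched half `x0` parameterizes `(m, logs)`, and the other half is shifted and scaled, which is why the reverse direction used at inference is available in closed form. A tiny standalone round-trip check of that update (toy tensors, mask left out for brevity):

import torch

x1 = torch.randn(2, 4, 6)                         # second half of the channels
m = torch.randn(2, 4, 6)                          # predicted shift
logs = torch.randn(2, 4, 6) * 0.1                 # predicted log-scale (zeros when mean_only=True)

y1 = m + x1 * torch.exp(logs)                     # forward direction; log-det is logs.sum()
x1_rec = (y1 - m) * torch.exp(-logs)              # reverse direction used when reverse=True
print(torch.allclose(x1, x1_rec, atol=1e-5))      # True: the coupling is exactly invertible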
style_bert_vits2/models/monotonic_alignment.py ADDED
@@ -0,0 +1,89 @@
1
+ """
2
+ The docstrings for the functions below were generated by GPT-4 during refactoring,
3
+ so they are not guaranteed to match the code exactly; treat them as a rough reference only.
4
+ """
5
+
6
+ from typing import Any
7
+
8
+ import numba
9
+ import torch
10
+ from numpy import float32, int32, zeros
11
+
12
+
13
+ def maximum_path(neg_cent: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
14
+ """
15
+ Compute the maximum path using the given negative cross-entropy and mask.
16
+
17
+ Args:
18
+ neg_cent (torch.Tensor): Tensor of negative cross-entropy values
19
+ mask (torch.Tensor): Mask tensor
20
+
21
+ Returns:
22
+ Tensor: Tensor representing the computed maximum path
23
+ """
24
+
25
+ device = neg_cent.device
26
+ dtype = neg_cent.dtype
27
+ neg_cent = neg_cent.data.cpu().numpy().astype(float32)
28
+ path = zeros(neg_cent.shape, dtype=int32)
29
+
30
+ t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
31
+ t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
32
+ __maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
33
+
34
+ return torch.from_numpy(path).to(device=device, dtype=dtype)
35
+
36
+
37
+ @numba.jit(
38
+ numba.void(
39
+ numba.int32[:, :, ::1],
40
+ numba.float32[:, :, ::1],
41
+ numba.int32[::1],
42
+ numba.int32[::1],
43
+ ),
44
+ nopython=True,
45
+ nogil=True,
46
+ ) # type: ignore
47
+ def __maximum_path_jit(paths: Any, values: Any, t_ys: Any, t_xs: Any) -> None:
48
+ """
49
+ 与えられたパス、値、およびターゲットの y と x 座標を使用して JIT で最大パスを計算する
50
+
51
+ Args:
52
+ paths: 計算されたパスを格納するための整数型の 3 次元配列
53
+ values: 値を格納するための浮動小数点型の 3 次元配列
54
+ t_ys: ターゲットの y 座標を格納するための整数型の 1 次元配列
55
+ t_xs: ターゲットの x 座標を格納するための整数型の 1 次元配列
56
+ """
57
+
58
+ b = paths.shape[0]
59
+ max_neg_val = -1e9
60
+ for i in range(int(b)):
61
+ path = paths[i]
62
+ value = values[i]
63
+ t_y = t_ys[i]
64
+ t_x = t_xs[i]
65
+
66
+ v_prev = v_cur = 0.0
67
+ index = t_x - 1
68
+
69
+ for y in range(t_y):
70
+ for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
71
+ if x == y:
72
+ v_cur = max_neg_val
73
+ else:
74
+ v_cur = value[y - 1, x]
75
+ if x == 0:
76
+ if y == 0:
77
+ v_prev = 0.0
78
+ else:
79
+ v_prev = max_neg_val
80
+ else:
81
+ v_prev = value[y - 1, x - 1]
82
+ value[y, x] += max(v_prev, v_cur)
83
+
84
+ for y in range(t_y - 1, -1, -1):
85
+ path[y, index] = 1
86
+ if index != 0 and (
87
+ index == y or value[y - 1, index] < value[y - 1, index - 1]
88
+ ):
89
+ index = index - 1
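A small usage sketch for maximum_path (editorial example, not part of the commit): it computes a monotonic alignment for one batch element with 4 decoder frames and 3 text tokens, assuming numba is installed and the module is importable as shown.

import torch
from style_bert_vits2.models.monotonic_alignment import maximum_path  # assumed import path

neg_cent = torch.randn(1, 4, 3)  # (batch, decoder frames, text tokens) similarity scores
mask = torch.ones(1, 4, 3)       # every (frame, token) pair is valid
path = maximum_path(neg_cent, mask)
print(path[0])  # 0/1 matrix: exactly one token per frame, token index never decreases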
style_bert_vits2/models/transforms.py ADDED
@@ -0,0 +1,215 @@
1
+ from typing import Optional
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch.nn import functional as F
6
+
7
+
8
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
9
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
10
+ DEFAULT_MIN_DERIVATIVE = 1e-3
11
+
12
+
13
+ def piecewise_rational_quadratic_transform(
14
+ inputs: torch.Tensor,
15
+ unnormalized_widths: torch.Tensor,
16
+ unnormalized_heights: torch.Tensor,
17
+ unnormalized_derivatives: torch.Tensor,
18
+ inverse: bool = False,
19
+ tails: Optional[str] = None,
20
+ tail_bound: float = 1.0,
21
+ min_bin_width: float = DEFAULT_MIN_BIN_WIDTH,
22
+ min_bin_height: float = DEFAULT_MIN_BIN_HEIGHT,
23
+ min_derivative: float = DEFAULT_MIN_DERIVATIVE,
24
+ ) -> tuple[torch.Tensor, torch.Tensor]:
25
+
26
+ if tails is None:
27
+ spline_fn = rational_quadratic_spline
28
+ spline_kwargs = {}
29
+ else:
30
+ spline_fn = unconstrained_rational_quadratic_spline
31
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
32
+
33
+ outputs, logabsdet = spline_fn(
34
+ inputs=inputs,
35
+ unnormalized_widths=unnormalized_widths,
36
+ unnormalized_heights=unnormalized_heights,
37
+ unnormalized_derivatives=unnormalized_derivatives,
38
+ inverse=inverse,
39
+ min_bin_width=min_bin_width,
40
+ min_bin_height=min_bin_height,
41
+ min_derivative=min_derivative,
42
+ **spline_kwargs, # type: ignore
43
+ )
44
+ return outputs, logabsdet
45
+
46
+
47
+ def searchsorted(
48
+ bin_locations: torch.Tensor, inputs: torch.Tensor, eps: float = 1e-6
49
+ ) -> torch.Tensor:
50
+ bin_locations[..., -1] += eps
51
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
52
+
53
+
54
+ def unconstrained_rational_quadratic_spline(
55
+ inputs: torch.Tensor,
56
+ unnormalized_widths: torch.Tensor,
57
+ unnormalized_heights: torch.Tensor,
58
+ unnormalized_derivatives: torch.Tensor,
59
+ inverse: bool = False,
60
+ tails: str = "linear",
61
+ tail_bound: float = 1.0,
62
+ min_bin_width: float = DEFAULT_MIN_BIN_WIDTH,
63
+ min_bin_height: float = DEFAULT_MIN_BIN_HEIGHT,
64
+ min_derivative: float = DEFAULT_MIN_DERIVATIVE,
65
+ ) -> tuple[torch.Tensor, torch.Tensor]:
66
+
67
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
68
+ outside_interval_mask = ~inside_interval_mask
69
+
70
+ outputs = torch.zeros_like(inputs)
71
+ logabsdet = torch.zeros_like(inputs)
72
+
73
+ if tails == "linear":
74
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
75
+ constant = np.log(np.exp(1 - min_derivative) - 1)
76
+ unnormalized_derivatives[..., 0] = constant
77
+ unnormalized_derivatives[..., -1] = constant
78
+
79
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
80
+ logabsdet[outside_interval_mask] = 0
81
+ else:
82
+ raise RuntimeError(f"{tails} tails are not implemented.")
83
+
84
+ (
85
+ outputs[inside_interval_mask],
86
+ logabsdet[inside_interval_mask],
87
+ ) = rational_quadratic_spline(
88
+ inputs=inputs[inside_interval_mask],
89
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
90
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
91
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
92
+ inverse=inverse,
93
+ left=-tail_bound,
94
+ right=tail_bound,
95
+ bottom=-tail_bound,
96
+ top=tail_bound,
97
+ min_bin_width=min_bin_width,
98
+ min_bin_height=min_bin_height,
99
+ min_derivative=min_derivative,
100
+ )
101
+
102
+ return outputs, logabsdet
103
+
104
+
105
+ def rational_quadratic_spline(
106
+ inputs: torch.Tensor,
107
+ unnormalized_widths: torch.Tensor,
108
+ unnormalized_heights: torch.Tensor,
109
+ unnormalized_derivatives: torch.Tensor,
110
+ inverse: bool = False,
111
+ left: float = 0.0,
112
+ right: float = 1.0,
113
+ bottom: float = 0.0,
114
+ top: float = 1.0,
115
+ min_bin_width: float = DEFAULT_MIN_BIN_WIDTH,
116
+ min_bin_height: float = DEFAULT_MIN_BIN_HEIGHT,
117
+ min_derivative: float = DEFAULT_MIN_DERIVATIVE,
118
+ ) -> tuple[torch.Tensor, torch.Tensor]:
119
+
120
+ if torch.min(inputs) < left or torch.max(inputs) > right:
121
+ raise ValueError("Input to a transform is not within its domain")
122
+
123
+ num_bins = unnormalized_widths.shape[-1]
124
+
125
+ if min_bin_width * num_bins > 1.0:
126
+ raise ValueError("Minimal bin width too large for the number of bins")
127
+ if min_bin_height * num_bins > 1.0:
128
+ raise ValueError("Minimal bin height too large for the number of bins")
129
+
130
+ widths = F.softmax(unnormalized_widths, dim=-1)
131
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
132
+ cumwidths = torch.cumsum(widths, dim=-1)
133
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
134
+ cumwidths = (right - left) * cumwidths + left
135
+ cumwidths[..., 0] = left
136
+ cumwidths[..., -1] = right
137
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
138
+
139
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
140
+
141
+ heights = F.softmax(unnormalized_heights, dim=-1)
142
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
143
+ cumheights = torch.cumsum(heights, dim=-1)
144
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
145
+ cumheights = (top - bottom) * cumheights + bottom
146
+ cumheights[..., 0] = bottom
147
+ cumheights[..., -1] = top
148
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
149
+
150
+ if inverse:
151
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
152
+ else:
153
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
154
+
155
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
156
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
157
+
158
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
159
+ delta = heights / widths
160
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
161
+
162
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
163
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
164
+
165
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
166
+
167
+ if inverse:
168
+ a = (inputs - input_cumheights) * (
169
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
170
+ ) + input_heights * (input_delta - input_derivatives)
171
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
172
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
173
+ )
174
+ c = -input_delta * (inputs - input_cumheights)
175
+
176
+ discriminant = b.pow(2) - 4 * a * c
177
+ assert (discriminant >= 0).all()
178
+
179
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
180
+ outputs = root * input_bin_widths + input_cumwidths
181
+
182
+ theta_one_minus_theta = root * (1 - root)
183
+ denominator = input_delta + (
184
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
185
+ * theta_one_minus_theta
186
+ )
187
+ derivative_numerator = input_delta.pow(2) * (
188
+ input_derivatives_plus_one * root.pow(2)
189
+ + 2 * input_delta * theta_one_minus_theta
190
+ + input_derivatives * (1 - root).pow(2)
191
+ )
192
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
193
+
194
+ return outputs, -logabsdet
195
+ else:
196
+ theta = (inputs - input_cumwidths) / input_bin_widths
197
+ theta_one_minus_theta = theta * (1 - theta)
198
+
199
+ numerator = input_heights * (
200
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
201
+ )
202
+ denominator = input_delta + (
203
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
204
+ * theta_one_minus_theta
205
+ )
206
+ outputs = input_cumheights + numerator / denominator
207
+
208
+ derivative_numerator = input_delta.pow(2) * (
209
+ input_derivatives_plus_one * theta.pow(2)
210
+ + 2 * input_delta * theta_one_minus_theta
211
+ + input_derivatives * (1 - theta).pow(2)
212
+ )
213
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
214
+
215
+ return outputs, logabsdet
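To show how the spline above is typically driven, here is a round-trip sketch (editorial example, not part of the commit). The shapes follow the convention that, with linear tails, the derivative parameters have one fewer bin than the widths and heights; the import path is an assumption.

import torch
from style_bert_vits2.models.transforms import piecewise_rational_quadratic_transform  # assumed path

num_bins = 10
x = torch.randn(2, 5)
w = torch.randn(2, 5, num_bins)      # unnormalized bin widths
h = torch.randn(2, 5, num_bins)      # unnormalized bin heights
d = torch.randn(2, 5, num_bins - 1)  # unnormalized derivatives at the interior knots
y, logdet = piecewise_rational_quadratic_transform(x, w, h, d, tails="linear", tail_bound=5.0)
x_rec, logdet_inv = piecewise_rational_quadratic_transform(
    y, w, h, d, inverse=True, tails="linear", tail_bound=5.0
)
# The transform is invertible and the forward/inverse log-determinants cancel out
print(torch.allclose(x, x_rec, atol=1e-4), torch.allclose(logdet, -logdet_inv, atol=1e-4))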
style_bert_vits2/models/utils/__init__.py ADDED
@@ -0,0 +1,264 @@
1
+ import glob
2
+ import logging
3
+ import os
4
+ import re
5
+ import subprocess
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING, Any, Optional, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+ from numpy.typing import NDArray
12
+
13
+ from style_bert_vits2.logging import logger
14
+ from style_bert_vits2.models.utils import checkpoints # type: ignore
15
+ from style_bert_vits2.models.utils import safetensors # type: ignore
16
+
17
+
18
+ if TYPE_CHECKING:
19
+ # tensorboard is not included in the dependencies when installed as a library, so import it only for type checking
20
+ from torch.utils.tensorboard import SummaryWriter
21
+
22
+
23
+ __is_matplotlib_imported = False
24
+
25
+
26
+ def summarize(
27
+ writer: "SummaryWriter",
28
+ global_step: int,
29
+ scalars: dict[str, float] = {},
30
+ histograms: dict[str, Any] = {},
31
+ images: dict[str, Any] = {},
32
+ audios: dict[str, Any] = {},
33
+ audio_sampling_rate: int = 22050,
34
+ ) -> None:
35
+ """
36
+ 指定されたデータを TensorBoard にまとめて追加する
37
+
38
+ Args:
39
+ writer (SummaryWriter): TensorBoard への書き込みを行うオブジェクト
40
+ global_step (int): グローバルステップ数
41
+ scalars (dict[str, float]): スカラー値の辞書
42
+ histograms (dict[str, Any]): ヒストグラムの辞書
43
+ images (dict[str, Any]): 画像データの辞書
44
+ audios (dict[str, Any]): 音声データの辞書
45
+ audio_sampling_rate (int): 音声データのサンプリングレート
46
+ """
47
+ for k, v in scalars.items():
48
+ writer.add_scalar(k, v, global_step)
49
+ for k, v in histograms.items():
50
+ writer.add_histogram(k, v, global_step)
51
+ for k, v in images.items():
52
+ writer.add_image(k, v, global_step, dataformats="HWC")
53
+ for k, v in audios.items():
54
+ writer.add_audio(k, v, global_step, audio_sampling_rate)
55
+
56
+
57
+ def is_resuming(dir_path: Union[str, Path]) -> bool:
58
+ """
59
+ 指定されたディレクトリパスに再開可能なモデルが存在するかどうかを返す
60
+
61
+ Args:
62
+ dir_path: チェックするディレクトリのパス
63
+
64
+ Returns:
65
+ bool: 再開可能なモデルが存在するかどうか
66
+ """
67
+ # The JP-Extra version has no DUR and may add WD etc., so resumability is judged from G checkpoints alone
68
+ g_list = glob.glob(os.path.join(dir_path, "G_*.pth"))
69
+ # d_list = glob.glob(os.path.join(dir_path, "D_*.pth"))
70
+ # dur_list = glob.glob(os.path.join(dir_path, "DUR_*.pth"))
71
+ return len(g_list) > 0
72
+
73
+
74
+ def plot_spectrogram_to_numpy(spectrogram: NDArray[Any]) -> NDArray[Any]:
75
+ """
76
+ 指定されたスペクトログラムを画像データに変換する
77
+
78
+ Args:
79
+ spectrogram (NDArray[Any]): スペクトログラム
80
+
81
+ Returns:
82
+ NDArray[Any]: 画像データ
83
+ """
84
+
85
+ global __is_matplotlib_imported
86
+ if not __is_matplotlib_imported:
87
+ import matplotlib
88
+
89
+ matplotlib.use("Agg")
90
+ __is_matplotlib_imported = True
91
+ mpl_logger = logging.getLogger("matplotlib")
92
+ mpl_logger.setLevel(logging.WARNING)
93
+ import matplotlib.pylab as plt
94
+ import numpy as np
95
+
96
+ fig, ax = plt.subplots(figsize=(10, 2))
97
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
98
+ plt.colorbar(im, ax=ax)
99
+ plt.xlabel("Frames")
100
+ plt.ylabel("Channels")
101
+ plt.tight_layout()
102
+
103
+ fig.canvas.draw()
104
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # type: ignore
105
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
106
+ plt.close()
107
+ return data
108
+
109
+
110
+ def plot_alignment_to_numpy(
111
+ alignment: NDArray[Any], info: Optional[str] = None
112
+ ) -> NDArray[Any]:
113
+ """
114
+ 指定されたアライメントを画像データに変換する
115
+
116
+ Args:
117
+ alignment (NDArray[Any]): アライメント
118
+ info (Optional[str]): 画像に追加する情報
119
+
120
+ Returns:
121
+ NDArray[Any]: 画像データ
122
+ """
123
+
124
+ global __is_matplotlib_imported
125
+ if not __is_matplotlib_imported:
126
+ import matplotlib
127
+
128
+ matplotlib.use("Agg")
129
+ __is_matplotlib_imported = True
130
+ mpl_logger = logging.getLogger("matplotlib")
131
+ mpl_logger.setLevel(logging.WARNING)
132
+ import matplotlib.pylab as plt
133
+
134
+ fig, ax = plt.subplots(figsize=(6, 4))
135
+ im = ax.imshow(
136
+ alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
137
+ )
138
+ fig.colorbar(im, ax=ax)
139
+ xlabel = "Decoder timestep"
140
+ if info is not None:
141
+ xlabel += "\n\n" + info
142
+ plt.xlabel(xlabel)
143
+ plt.ylabel("Encoder timestep")
144
+ plt.tight_layout()
145
+
146
+ fig.canvas.draw()
147
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)  # type: ignore
148
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
149
+ plt.close()
150
+ return data
151
+
152
+
153
+ def load_wav_to_torch(full_path: Union[str, Path]) -> tuple[torch.FloatTensor, int]:
154
+ """
155
+ 指定された音声ファイルを読み込み、PyTorch のテンソルに変換して返す
156
+
157
+ Args:
158
+ full_path (Union[str, Path]): 音声ファイルのパス
159
+
160
+ Returns:
161
+ tuple[torch.FloatTensor, int]: 音声データのテンソルとサンプリングレート
162
+ """
163
+
164
+ # This function is only used during training, so scipy is imported lazily here to keep
165
+ # the style_bert_vits2 library from depending on the heavy scipy package
166
+ try:
167
+ from scipy.io.wavfile import read
168
+ except ImportError:
169
+ raise ImportError("scipy is required to load wav file")
170
+
171
+ sampling_rate, data = read(full_path)
172
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
173
+
174
+
175
+ def load_filepaths_and_text(
176
+ filename: Union[str, Path], split: str = "|"
177
+ ) -> list[list[str]]:
178
+ """
179
+ 指定されたファイルからファイルパスとテキストを読み込む
180
+
181
+ Args:
182
+ filename (Union[str, Path]): ファイルのパス
183
+ split (str): ファイルの区切り文字 (デフォルト: "|")
184
+
185
+ Returns:
186
+ list[list[str]]: ファイルパスとテキストのリスト
187
+ """
188
+
189
+ with open(filename, encoding="utf-8") as f:
190
+ filepaths_and_text = [line.strip().split(split) for line in f]
191
+ return filepaths_and_text
192
+
193
+
194
+ def get_logger(
195
+ model_dir_path: Union[str, Path], filename: str = "train.log"
196
+ ) -> logging.Logger:
197
+ """
198
+ ロガーを取得する
199
+
200
+ Args:
201
+ model_dir_path (Union[str, Path]): ログを保存するディレクトリのパス
202
+ filename (str): ログファイルの名前 (デフォルト: "train.log")
203
+
204
+ Returns:
205
+ logging.Logger: ロガー
206
+ """
207
+
208
+ global logger
209
+ logger = logging.getLogger(os.path.basename(model_dir_path))
210
+ logger.setLevel(logging.DEBUG)
211
+
212
+ formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
213
+ if not os.path.exists(model_dir_path):
214
+ os.makedirs(model_dir_path)
215
+ h = logging.FileHandler(os.path.join(model_dir_path, filename))
216
+ h.setLevel(logging.DEBUG)
217
+ h.setFormatter(formatter)
218
+ logger.addHandler(h)
219
+ return logger
220
+
221
+
222
+ def get_steps(model_path: Union[str, Path]) -> Optional[int]:
223
+ """
224
+ モデルのパスからイテレーション回数を取得する
225
+
226
+ Args:
227
+ model_path (Union[str, Path]): モデルのパス
228
+
229
+ Returns:
230
+ Optional[int]: イテレーション回数
231
+ """
232
+
233
+ matches = re.findall(r"\d+", model_path) # type: ignore
234
+ return matches[-1] if matches else None
235
+
236
+
237
+ def check_git_hash(model_dir_path: Union[str, Path]) -> None:
238
+ """
239
+ モデルのディレクトリに .git ディレクトリが存在する場合、ハッシュ値を比較する
240
+
241
+ Args:
242
+ model_dir_path (Union[str, Path]): モデルのディレクトリのパス
243
+ """
244
+
245
+ source_dir = os.path.dirname(os.path.realpath(__file__))
246
+ if not os.path.exists(os.path.join(source_dir, ".git")):
247
+ logger.warning(
248
+ f"{source_dir} is not a git repository, therefore hash value comparison will be ignored."
249
+ )
250
+ return
251
+
252
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
253
+
254
+ path = os.path.join(model_dir_path, "githash")
255
+ if os.path.exists(path):
256
+ with open(path, encoding="utf-8") as f:
257
+ saved_hash = f.read()
258
+ if saved_hash != cur_hash:
259
+ logger.warning(
260
+ f"git hash values are different. {saved_hash[:8]}(saved) != {cur_hash[:8]}(current)"
261
+ )
262
+ else:
263
+ with open(path, "w", encoding="utf-8") as f:
264
+ f.write(cur_hash)
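The helpers above are mainly called from the training loop; the sketch below (not part of the commit) shows the intended call pattern with placeholder paths and dummy data, assuming tensorboard and matplotlib are installed.

import numpy as np
from torch.utils.tensorboard import SummaryWriter
from style_bert_vits2.models import utils  # assumed import path

writer = SummaryWriter("logs/demo")  # placeholder log directory
spec_img = utils.plot_spectrogram_to_numpy(np.random.rand(80, 200))  # dummy mel spectrogram
utils.summarize(
    writer,
    global_step=100,
    scalars={"train/loss": 0.123},
    images={"train/mel": spec_img},
)
print(utils.is_resuming("logs/demo"))  # False until a G_*.pth checkpoint exists there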
style_bert_vits2/models/utils/checkpoints.py ADDED
@@ -0,0 +1,202 @@
1
+ import glob
2
+ import os
3
+ import re
4
+ from pathlib import Path
5
+ from typing import Any, Optional, Union
6
+
7
+ import torch
8
+
9
+ from style_bert_vits2.logging import logger
10
+
11
+
12
+ def load_checkpoint(
13
+ checkpoint_path: Union[str, Path],
14
+ model: torch.nn.Module,
15
+ optimizer: Optional[torch.optim.Optimizer] = None,
16
+ skip_optimizer: bool = False,
17
+ for_infer: bool = False,
18
+ ) -> tuple[torch.nn.Module, Optional[torch.optim.Optimizer], float, int]:
19
+ """
20
+ 指定されたパスからチェックポイントを読み込み、モデルとオプティマイザーを更新する。
21
+
22
+ Args:
23
+ checkpoint_path (Union[str, Path]): チェックポイントファイルのパス
24
+ model (torch.nn.Module): 更新するモデル
25
+ optimizer (Optional[torch.optim.Optimizer]): 更新するオプティマイザー。None の場合は更新しない
26
+ skip_optimizer (bool): オプティマイザーの更新をスキップするかどうかのフラグ
27
+ for_infer (bool): 推論用に読み込むかどうかのフラグ
28
+
29
+ Returns:
30
+ tuple[torch.nn.Module, Optional[torch.optim.Optimizer], float, int]: 更新されたモデルとオプティマイザー、学習率、イテレーション回数
31
+ """
32
+
33
+ assert os.path.isfile(checkpoint_path)
34
+ checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
35
+ iteration = checkpoint_dict["iteration"]
36
+ learning_rate = checkpoint_dict["learning_rate"]
37
+ logger.info(
38
+ f"Loading model and optimizer at iteration {iteration} from {checkpoint_path}"
39
+ )
40
+ if (
41
+ optimizer is not None
42
+ and not skip_optimizer
43
+ and checkpoint_dict["optimizer"] is not None
44
+ ):
45
+ optimizer.load_state_dict(checkpoint_dict["optimizer"])
46
+ elif optimizer is None and not skip_optimizer:
47
+ # else: when loading for inference or resuming from a checkpoint, disable this branch and enable the branch above instead
48
+ new_opt_dict = optimizer.state_dict() # type: ignore
49
+ new_opt_dict_params = new_opt_dict["param_groups"][0]["params"]
50
+ new_opt_dict["param_groups"] = checkpoint_dict["optimizer"]["param_groups"]
51
+ new_opt_dict["param_groups"][0]["params"] = new_opt_dict_params
52
+ optimizer.load_state_dict(new_opt_dict) # type: ignore
53
+
54
+ saved_state_dict = checkpoint_dict["model"]
55
+ if hasattr(model, "module"):
56
+ state_dict = model.module.state_dict()
57
+ else:
58
+ state_dict = model.state_dict()
59
+
60
+ new_state_dict = {}
61
+ for k, v in state_dict.items():
62
+ try:
63
+ # assert "emb_g" not in k
64
+ new_state_dict[k] = saved_state_dict[k]
65
+ assert saved_state_dict[k].shape == v.shape, (
66
+ saved_state_dict[k].shape,
67
+ v.shape,
68
+ )
69
+ except (KeyError, AssertionError):
70
+ # For upgrading from the old version
71
+ if "ja_bert_proj" in k:
72
+ v = torch.zeros_like(v)
73
+ logger.warning(
74
+ f"Seems you are using the old version of the model, the {k} is automatically set to zero for backward compatibility"
75
+ )
76
+ elif "enc_q" in k and for_infer:
77
+ continue
78
+ else:
79
+ logger.error(f"{k} is not in the checkpoint {checkpoint_path}")
80
+
81
+ new_state_dict[k] = v
82
+
83
+ if hasattr(model, "module"):
84
+ model.module.load_state_dict(new_state_dict, strict=False)
85
+ else:
86
+ model.load_state_dict(new_state_dict, strict=False)
87
+
88
+ logger.info(f"Loaded '{checkpoint_path}' (iteration {iteration})")
89
+
90
+ return model, optimizer, learning_rate, iteration
91
+
92
+
93
+ def save_checkpoint(
94
+ model: torch.nn.Module,
95
+ optimizer: Union[torch.optim.Optimizer, torch.optim.AdamW],
96
+ learning_rate: float,
97
+ iteration: int,
98
+ checkpoint_path: Union[str, Path],
99
+ ) -> None:
100
+ """
101
+ モデルとオプティマイザーの状態を指定されたパスに保存する。
102
+
103
+ Args:
104
+ model (torch.nn.Module): 保存するモデル
105
+ optimizer (Union[torch.optim.Optimizer, torch.optim.AdamW]): 保存するオプティマイザー
106
+ learning_rate (float): 学習率
107
+ iteration (int): イテレーション回数
108
+ checkpoint_path (Union[str, Path]): 保存先のパス
109
+ """
110
+ logger.info(
111
+ f"Saving model and optimizer state at iteration {iteration} to {checkpoint_path}"
112
+ )
113
+ if hasattr(model, "module"):
114
+ state_dict = model.module.state_dict()
115
+ else:
116
+ state_dict = model.state_dict()
117
+ torch.save(
118
+ {
119
+ "model": state_dict,
120
+ "iteration": iteration,
121
+ "optimizer": optimizer.state_dict(),
122
+ "learning_rate": learning_rate,
123
+ },
124
+ checkpoint_path,
125
+ )
126
+
127
+
128
+ def clean_checkpoints(
129
+ model_dir_path: Union[str, Path] = "logs/44k/",
130
+ n_ckpts_to_keep: int = 2,
131
+ sort_by_time: bool = True,
132
+ ) -> None:
133
+ """
134
+ 指定されたディレクトリから古いチェックポイントを削除して空き容量を確保する
135
+
136
+ Args:
137
+ model_dir_path (Union[str, Path]): モデルが保存されているディレクトリのパス
138
+ n_ckpts_to_keep (int): 保持するチェックポイントの数(G_0.pth と D_0.pth を除く)
139
+ sort_by_time (bool): True の場合、時間順に削除。False の場合、名前順に削除
140
+ """
141
+
142
+ ckpts_files = [
143
+ f
144
+ for f in os.listdir(model_dir_path)
145
+ if os.path.isfile(os.path.join(model_dir_path, f))
146
+ ]
147
+
148
+ def name_key(_f: str) -> int:
149
+ return int(re.compile("._(\\d+)\\.pth").match(_f).group(1)) # type: ignore
150
+
151
+ def time_key(_f: str) -> float:
152
+ return os.path.getmtime(os.path.join(model_dir_path, _f))
153
+
154
+ sort_key = time_key if sort_by_time else name_key
155
+
156
+ def x_sorted(_x: str) -> list[str]:
157
+ return sorted(
158
+ [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
159
+ key=sort_key,
160
+ )
161
+
162
+ to_del = [
163
+ os.path.join(model_dir_path, fn)
164
+ for fn in (
165
+ x_sorted("G_")[:-n_ckpts_to_keep]
166
+ + x_sorted("D_")[:-n_ckpts_to_keep]
167
+ + x_sorted("WD_")[:-n_ckpts_to_keep]
168
+ + x_sorted("DUR_")[:-n_ckpts_to_keep]
169
+ )
170
+ ]
171
+
172
+ def del_info(fn: str) -> None:
173
+ return logger.info(f"Free up space by deleting ckpt {fn}")
174
+
175
+ def del_routine(x: str) -> list[Any]:
176
+ return [os.remove(x), del_info(x)]
177
+
178
+ [del_routine(fn) for fn in to_del]
179
+
180
+
181
+ def get_latest_checkpoint_path(
182
+ model_dir_path: Union[str, Path], regex: str = "G_*.pth"
183
+ ) -> str:
184
+ """
185
+ 指定されたディレクトリから最新のチェックポイントのパスを取得する
186
+
187
+ Args:
188
+ model_dir_path (Union[str, Path]): モデルが保存されているディレクトリのパス
189
+ regex (str): チェックポイントのファイル名の正規表現
190
+
191
+ Returns:
192
+ str: 最新のチェックポイントのパス
193
+ """
194
+
195
+ f_list = glob.glob(os.path.join(str(model_dir_path), regex))
196
+ f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
197
+ try:
198
+ x = f_list[-1]
199
+ except IndexError:
200
+ raise ValueError(f"No checkpoint found in {model_dir_path} with regex {regex}")
201
+
202
+ return x
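A minimal save/load round trip for the functions above (editorial sketch, not part of the commit); the linear model, learning rate, and paths are placeholders.

import os
import torch
from style_bert_vits2.models.utils import checkpoints  # assumed import path

os.makedirs("logs/demo", exist_ok=True)
model = torch.nn.Linear(4, 4)  # stand-in for the generator network
optim = torch.optim.AdamW(model.parameters(), lr=2e-4)
checkpoints.save_checkpoint(model, optim, 2e-4, iteration=100, checkpoint_path="logs/demo/G_100.pth")

latest = checkpoints.get_latest_checkpoint_path("logs/demo", regex="G_*.pth")
model, optim, lr, step = checkpoints.load_checkpoint(latest, model, optim)
checkpoints.clean_checkpoints("logs/demo", n_ckpts_to_keep=2)  # keeps the newest checkpoints per prefix
print(latest, step)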
style_bert_vits2/models/utils/safetensors.py ADDED
@@ -0,0 +1,91 @@
1
+ from pathlib import Path
2
+ from typing import Any, Optional, Union
3
+
4
+ import torch
5
+ from safetensors import safe_open
6
+ from safetensors.torch import save_file
7
+
8
+ from style_bert_vits2.logging import logger
9
+
10
+
11
+ def load_safetensors(
12
+ checkpoint_path: Union[str, Path],
13
+ model: torch.nn.Module,
14
+ for_infer: bool = False,
15
+ ) -> tuple[torch.nn.Module, Optional[int]]:
16
+ """
17
+ 指定されたパスから safetensors モデルを読み込み、モデルとイテレーションを返す。
18
+
19
+ Args:
20
+ checkpoint_path (Union[str, Path]): モデルのチェックポイントファイルのパス
21
+ model (torch.nn.Module): 読み込む対象のモデル
22
+ for_infer (bool): 推論用に読み込むかどうかのフラグ
23
+
24
+ Returns:
25
+ tuple[torch.nn.Module, Optional[int]]: 読み込まれたモデルとイテレーション回数(存在する場合)
26
+ """
27
+
28
+ tensors: dict[str, Any] = {}
29
+ iteration: Optional[int] = None
30
+ with safe_open(str(checkpoint_path), framework="pt", device="cpu") as f: # type: ignore
31
+ for key in f.keys():
32
+ if key == "iteration":
33
+ iteration = f.get_tensor(key).item()
34
+ tensors[key] = f.get_tensor(key)
35
+ if hasattr(model, "module"):
36
+ result = model.module.load_state_dict(tensors, strict=False)
37
+ else:
38
+ result = model.load_state_dict(tensors, strict=False)
39
+ for key in result.missing_keys:
40
+ if key.startswith("enc_q") and for_infer:
41
+ continue
42
+ logger.warning(f"Missing key: {key}")
43
+ for key in result.unexpected_keys:
44
+ if key == "iteration":
45
+ continue
46
+ logger.warning(f"Unexpected key: {key}")
47
+ if iteration is None:
48
+ logger.info(f"Loaded '{checkpoint_path}'")
49
+ else:
50
+ logger.info(f"Loaded '{checkpoint_path}' (iteration {iteration})")
51
+
52
+ return model, iteration
53
+
54
+
55
+ def save_safetensors(
56
+ model: torch.nn.Module,
57
+ iteration: int,
58
+ checkpoint_path: Union[str, Path],
59
+ is_half: bool = False,
60
+ for_infer: bool = False,
61
+ ) -> None:
62
+ """
63
+ モデルを safetensors 形式で保存する。
64
+
65
+ Args:
66
+ model (torch.nn.Module): 保存するモデル
67
+ iteration (int): イテレーション回数
68
+ checkpoint_path (Union[str, Path]): 保存先のパス
69
+ is_half (bool): モデルを半精度で保存するかどうかのフラグ
70
+ for_infer (bool): 推論用に保存するかどうかのフラグ
71
+ """
72
+
73
+ if hasattr(model, "module"):
74
+ state_dict = model.module.state_dict()
75
+ else:
76
+ state_dict = model.state_dict()
77
+ keys = []
78
+ for k in state_dict:
79
+ if "enc_q" in k and for_infer:
80
+ continue
81
+ keys.append(k)
82
+
83
+ new_dict = (
84
+ {k: state_dict[k].half() for k in keys}
85
+ if is_half
86
+ else {k: state_dict[k] for k in keys}
87
+ )
88
+ new_dict["iteration"] = torch.LongTensor([iteration])
89
+ logger.info(f"Saved safetensors to {checkpoint_path}")
90
+
91
+ save_file(new_dict, checkpoint_path)
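The sketch below (not part of the commit) exports a toy model with save_safetensors and reads it back with load_safetensors; the model and file name are placeholders.

import torch
from style_bert_vits2.models.utils import safetensors  # assumed import path

model = torch.nn.Linear(4, 4)  # placeholder model
safetensors.save_safetensors(model, iteration=100, checkpoint_path="demo.safetensors")
model, iteration = safetensors.load_safetensors("demo.safetensors", model, for_infer=True)
print(iteration)  # -> 100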
style_bert_vits2/nlp/__init__.py ADDED
@@ -0,0 +1,120 @@
1
+ from typing import TYPE_CHECKING, Optional
2
+
3
+ from style_bert_vits2.constants import Languages
4
+ from style_bert_vits2.nlp.symbols import (
5
+ LANGUAGE_ID_MAP,
6
+ LANGUAGE_TONE_START_MAP,
7
+ SYMBOLS,
8
+ )
9
+
10
+
11
+ # __init__.py is executed the moment any module under this package is imported
12
+ # Importing PyTorch is heavy, so import it only for type checking
13
+ if TYPE_CHECKING:
14
+ import torch
15
+
16
+
17
+ __symbol_to_id = {s: i for i, s in enumerate(SYMBOLS)}
18
+
19
+
20
+ def extract_bert_feature(
21
+ text: str,
22
+ word2ph: list[int],
23
+ language: Languages,
24
+ device: str,
25
+ assist_text: Optional[str] = None,
26
+ assist_text_weight: float = 0.7,
27
+ ) -> "torch.Tensor":
28
+ """
29
+ テキストから BERT の特徴量を抽出する
30
+
31
+ Args:
32
+ text (str): テキスト
33
+ word2ph (list[int]): 元のテキストの各文字に音素が何個割り当てられるかを表すリスト
34
+ language (Languages): テキストの言語
35
+ device (str): 推論に利用するデバイス
36
+ assist_text (Optional[str], optional): 補助テキスト (デフォルト: None)
37
+ assist_text_weight (float, optional): 補助テキストの重み (デフォルト: 0.7)
38
+
39
+ Returns:
40
+ torch.Tensor: BERT の特徴量
41
+ """
42
+
43
+ if language == Languages.JP:
44
+ from style_bert_vits2.nlp.japanese.bert_feature import extract_bert_feature
45
+ elif language == Languages.EN:
46
+ from style_bert_vits2.nlp.english.bert_feature import extract_bert_feature
47
+ elif language == Languages.ZH:
48
+ from style_bert_vits2.nlp.chinese.bert_feature import extract_bert_feature
49
+ else:
50
+ raise ValueError(f"Language {language} not supported")
51
+
52
+ return extract_bert_feature(text, word2ph, device, assist_text, assist_text_weight)
53
+
54
+
55
+ def clean_text(
56
+ text: str,
57
+ language: Languages,
58
+ use_jp_extra: bool = True,
59
+ raise_yomi_error: bool = False,
60
+ ) -> tuple[str, list[str], list[int], list[int]]:
61
+ """
62
+ テキストをクリーニングし、音素に変換する
63
+
64
+ Args:
65
+ text (str): クリーニングするテキスト
66
+ language (Languages): テキストの言語
67
+ use_jp_extra (bool, optional): テキストが日本語の場合に JP-Extra モデルを利用するかどうか。Defaults to True.
68
+ raise_yomi_error (bool, optional): False の場合、読めない文字が消えたような扱いとして処理される。Defaults to False.
69
+
70
+ Returns:
71
+ tuple[str, list[str], list[int], list[int]]: クリーニングされたテキストと、音素・アクセント・元のテキストの各文字に音素が何個割り当てられるかのリスト
72
+ """
73
+
74
+ # Changed to import inside if condition to avoid unnecessary import
75
+ if language == Languages.JP:
76
+ from style_bert_vits2.nlp.japanese.g2p import g2p
77
+ from style_bert_vits2.nlp.japanese.normalizer import normalize_text
78
+
79
+ norm_text = normalize_text(text)
80
+ phones, tones, word2ph = g2p(norm_text, use_jp_extra, raise_yomi_error)
81
+ elif language == Languages.EN:
82
+ from style_bert_vits2.nlp.english.g2p import g2p
83
+ from style_bert_vits2.nlp.english.normalizer import normalize_text
84
+
85
+ norm_text = normalize_text(text)
86
+ phones, tones, word2ph = g2p(norm_text)
87
+ elif language == Languages.ZH:
88
+ from style_bert_vits2.nlp.chinese.g2p import g2p
89
+ from style_bert_vits2.nlp.chinese.normalizer import normalize_text
90
+
91
+ norm_text = normalize_text(text)
92
+ phones, tones, word2ph = g2p(norm_text)
93
+ else:
94
+ raise ValueError(f"Language {language} not supported")
95
+
96
+ return norm_text, phones, tones, word2ph
97
+
98
+
99
+ def cleaned_text_to_sequence(
100
+ cleaned_phones: list[str], tones: list[int], language: Languages
101
+ ) -> tuple[list[int], list[int], list[int]]:
102
+ """
103
+ 音素リスト・アクセントリスト・言語を、テキスト内の対応する ID に変換する
104
+
105
+ Args:
106
+ cleaned_phones (list[str]): clean_text() でクリーニングされた音素のリスト
107
+ tones (list[int]): 各音素のアクセント
108
+ language (Languages): テキストの言語
109
+
110
+ Returns:
111
+ tuple[list[int], list[int], list[int]]: List of integers corresponding to the symbols in the text
112
+ """
113
+
114
+ phones = [__symbol_to_id[symbol] for symbol in cleaned_phones]
115
+ tone_start = LANGUAGE_TONE_START_MAP[language]
116
+ tones = [i + tone_start for i in tones]
117
+ lang_id = LANGUAGE_ID_MAP[language]
118
+ lang_ids = [lang_id for _ in phones]
119
+
120
+ return phones, tones, lang_ids
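Finally, a short sketch of the text front end exposed above (editorial example, not part of the commit): it normalizes a Japanese sentence, converts it to phonemes and tones, and maps them to IDs. It assumes the Japanese G2P dependencies (pyopenjtalk, etc.) listed in requirements.txt are installed.

from style_bert_vits2.constants import Languages
from style_bert_vits2.nlp import clean_text, cleaned_text_to_sequence

norm_text, phones, tones, word2ph = clean_text("こんにちは。", Languages.JP, use_jp_extra=True)
phone_ids, tone_ids, lang_ids = cleaned_text_to_sequence(phones, tones, Languages.JP)
print(norm_text)
print(list(zip(phones, tones))[:5])
print(phone_ids[:5], tone_ids[:5], lang_ids[:5])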