Kit-Lemonfoot committed
Commit 107eeac
1 Parent(s): da4bd39

Upload 77 files

.gitignore CHANGED
@@ -1 +1 @@
- __pycache__/

+ __pycache__/
README.md CHANGED
@@ -1,13 +1,13 @@
  ---
- title: Hololive Style-Bert-VITS2
- emoji: 😊▶️
  colorFrom: blue
- colorTo: blue
  sdk: gradio
- sdk_version: 4.12.0
  app_file: app.py
  pinned: false
  license: agpl-3.0
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

  ---
+ title: Style-Bert-VITS2 JVNV
+ emoji: 😡😊😱😫
  colorFrom: blue
+ colorTo: red
  sdk: gradio
+ sdk_version: 4.16.0
  app_file: app.py
  pinned: false
  license: agpl-3.0
  ---

+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -4,212 +4,39 @@ import argparse
4
  import datetime
5
  import os
6
  import sys
7
- import warnings
8
  import json
 
9
 
10
  import gradio as gr
11
- import numpy as np
12
  import torch
13
- from gradio.processing_utils import convert_to_16_bit_wav
14
-
15
- import utils
16
- from config import config
17
- from infer import get_net_g, infer
18
- from tools.log import logger
 
19
 
20
  is_hf_spaces = os.getenv("SYSTEM") == "spaces"
21
  limit = 150
22
 
23
-
24
- class Model:
25
- def __init__(self, model_path, config_path, style_vec_path, device):
26
- self.model_path = model_path
27
- self.config_path = config_path
28
- self.device = device
29
- self.style_vec_path = style_vec_path
30
- self.load()
31
-
32
- def load(self):
33
- self.hps = utils.get_hparams_from_file(self.config_path)
34
- self.spk2id = self.hps.data.spk2id
35
- self.num_styles = self.hps.data.num_styles
36
- if hasattr(self.hps.data, "style2id"):
37
- self.style2id = self.hps.data.style2id
38
- else:
39
- self.style2id = {str(i): i for i in range(self.num_styles)}
40
-
41
- self.style_vectors = np.load(self.style_vec_path)
42
- self.net_g = None
43
-
44
- def load_net_g(self):
45
- self.net_g = get_net_g(
46
- model_path=self.model_path,
47
- version=self.hps.version,
48
- device=self.device,
49
- hps=self.hps,
50
- )
51
-
52
- def get_style_vector(self, style_id, weight=1.0):
53
- mean = self.style_vectors[0]
54
- style_vec = self.style_vectors[style_id]
55
- style_vec = mean + (style_vec - mean) * weight
56
- return style_vec
57
-
58
- def get_style_vector_from_audio(self, audio_path, weight=1.0):
59
- from style_gen import extract_style_vector
60
-
61
- xvec = extract_style_vector(audio_path)
62
- mean = self.style_vectors[0]
63
- xvec = mean + (xvec - mean) * weight
64
- return xvec
65
-
66
- def infer(
67
- self,
68
- text,
69
- language="JP",
70
- sid=0,
71
- reference_audio_path=None,
72
- sdp_ratio=0.2,
73
- noise=0.6,
74
- noisew=0.8,
75
- length=1.0,
76
- line_split=True,
77
- split_interval=0.2,
78
- style_text="",
79
- style_weight=0.7,
80
- use_style_text=False,
81
- style="0",
82
- emotion_weight=1.0,
83
- ):
84
- if reference_audio_path == "":
85
- reference_audio_path = None
86
- if style_text == "" or not use_style_text:
87
- style_text = None
88
-
89
- if self.net_g is None:
90
- self.load_net_g()
91
- if reference_audio_path is None:
92
- style_id = self.style2id[style]
93
- style_vector = self.get_style_vector(style_id, emotion_weight)
94
- else:
95
- style_vector = self.get_style_vector_from_audio(
96
- reference_audio_path, emotion_weight
97
- )
98
- if not line_split:
99
- with torch.no_grad():
100
- audio = infer(
101
- text=text,
102
- sdp_ratio=sdp_ratio,
103
- noise_scale=noise,
104
- noise_scale_w=noisew,
105
- length_scale=length,
106
- sid=sid,
107
- language=language,
108
- hps=self.hps,
109
- net_g=self.net_g,
110
- device=self.device,
111
- style_text=style_text,
112
- style_weight=style_weight,
113
- style_vec=style_vector,
114
- )
115
- else:
116
- texts = text.split("\n")
117
- texts = [t for t in texts if t != ""]
118
- audios = []
119
- with torch.no_grad():
120
- for i, t in enumerate(texts):
121
- audios.append(
122
- infer(
123
- text=t,
124
- sdp_ratio=sdp_ratio,
125
- noise_scale=noise,
126
- noise_scale_w=noisew,
127
- length_scale=length,
128
- sid=sid,
129
- language=language,
130
- hps=self.hps,
131
- net_g=self.net_g,
132
- device=self.device,
133
- style_text=style_text,
134
- style_weight=style_weight,
135
- style_vec=style_vector,
136
- )
137
- )
138
- if i != len(texts) - 1:
139
- audios.append(np.zeros(int(44100 * split_interval)))
140
- audio = np.concatenate(audios)
141
- with warnings.catch_warnings():
142
- warnings.simplefilter("ignore")
143
- audio = convert_to_16_bit_wav(audio)
144
- return (self.hps.data.sampling_rate, audio)
145
-
146
-
147
- class ModelHolder:
148
- def __init__(self, root_dir, device):
149
- self.root_dir = root_dir
150
- self.device = device
151
- self.model_files_dict = {}
152
- self.current_model = None
153
- self.model_names = []
154
- self.models = []
155
- self.refresh()
156
-
157
- def refresh(self):
158
- self.model_files_dict = {}
159
- self.model_names = []
160
- self.current_model = None
161
- model_dirs = [
162
- d
163
- for d in os.listdir(self.root_dir)
164
- if os.path.isdir(os.path.join(self.root_dir, d))
165
- ]
166
- for model_name in model_dirs:
167
- model_dir = os.path.join(self.root_dir, model_name)
168
- model_files = [
169
- os.path.join(model_dir, f)
170
- for f in os.listdir(model_dir)
171
- if f.endswith(".pth") or f.endswith(".pt") or f.endswith(".safetensors")
172
- ]
173
- if len(model_files) == 0:
174
- logger.info(
175
- f"No model files found in {self.root_dir}/{model_name}, so skip it"
176
- )
177
- self.model_files_dict[model_name] = model_files
178
- self.model_names.append(model_name)
179
-
180
- def load_model(self, model_name, model_path):
181
- if model_name not in self.model_files_dict:
182
- raise Exception(f"モデル名{model_name}は存在しません")
183
- if model_path not in self.model_files_dict[model_name]:
184
- raise Exception(f"pthファイル{model_path}は存在しません")
185
- self.current_model = Model(
186
- model_path=model_path,
187
- config_path=os.path.join(self.root_dir, model_name, "config.json"),
188
- style_vec_path=os.path.join(self.root_dir, model_name, "style_vectors.npy"),
189
- device=self.device,
190
- )
191
- styles = list(self.current_model.style2id.keys())
192
- speakers = list(self.current_model.spk2id.keys())
193
- return (
194
- gr.Dropdown(choices=styles, value=styles[0]),
195
- gr.update(interactive=True, value="Synthesize"),
196
- gr.Dropdown(choices=speakers, value=speakers[0]),
197
- )
198
-
199
- def update_model_files_dropdown(self, model_name):
200
- model_files = self.model_files_dict[model_name]
201
- return gr.Dropdown(choices=model_files, value=model_files[0])
202
-
203
- def update_model_names_dropdown(self):
204
- self.refresh()
205
- initial_model_name = self.model_names[0]
206
- initial_model_files = self.model_files_dict[initial_model_name]
207
- return (
208
- gr.Dropdown(choices=self.model_names, value=initial_model_name),
209
- gr.Dropdown(choices=initial_model_files, value=initial_model_files[0]),
210
- gr.update(interactive=False), # For tts_button
211
- )
212
-
213
 
214
  def tts_fn(
215
  model_name,
@@ -223,56 +50,101 @@ def tts_fn(
223
  length_scale,
224
  line_split,
225
  split_interval,
226
- style_text,
 
 
 
227
  style_weight,
228
- use_style_text,
229
- emotion,
230
- emotion_weight,
231
  speaker,
232
  ):
233
  if len(text)<2:
234
- return "Please enter some text.", None
235
- #logger.info(f"Start TTS with {language}:\n{text}")
236
- #logger.info(f"Model: {model_holder.current_model.model_path}")
237
- #logger.info(f"SDP: {sdp_ratio}, Noise: {noise_scale}, Noise_W: {noise_scale_w}, Length: {length_scale}")
238
- #logger.info(f"Style text enabled: {use_style_text}, Style text: {style_text}, Style weight: {style_weight}")
239
- #logger.info(f"Style: {emotion}, Style weight: {emotion_weight}")
240
 
241
  if is_hf_spaces and len(text) > limit:
242
- return f"Too long! There is a character limit of {limit} characters.", None
243
 
244
  if(not model_holder.current_model):
245
- model_holder.load_model(model_name, model_path)
246
-
247
  if(model_holder.current_model.model_path != model_path):
248
- model_holder.load_model(model_name, model_path)
249
-
250
  speaker_id = model_holder.current_model.spk2id[speaker]
251
-
252
  start_time = datetime.datetime.now()
253
 
254
- sr, audio = model_holder.current_model.infer(
255
- text=text,
256
- language=language,
257
- sid=speaker_id,
258
- reference_audio_path=reference_audio_path,
259
- sdp_ratio=sdp_ratio,
260
- noise=noise_scale,
261
- noisew=noise_scale_w,
262
- length=length_scale,
263
- line_split=line_split,
264
- split_interval=split_interval,
265
- style_text=style_text,
266
- style_weight=style_weight,
267
- use_style_text=use_style_text,
268
- style=emotion,
269
- emotion_weight=emotion_weight,
270
- )
271
 
272
  end_time = datetime.datetime.now()
273
  duration = (end_time - start_time).total_seconds()
274
- logger.info(f"Successful inference, took {duration}s | {speaker} | {sdp_ratio}/{noise_scale}/{noise_scale_w}/{length_scale}/{emotion}/{emotion_weight} | {text}")
275
- return f"Success, time: {duration} seconds.", (sr, audio)
276
 
277
  def load_voicedata():
278
  print("Loading voice data...")
@@ -286,6 +158,7 @@ def load_voicedata():
286
  model_path = info['model_path']
287
  voice_name = info['title']
288
  speakerid = info['speakerid']
 
289
  image = info['cover']
290
  if not model_path in styledict.keys():
291
  conf=f"model_assets/{model_path}/config.json"
@@ -293,7 +166,7 @@ def load_voicedata():
293
  s2id = hps.data.style2id
294
  styledict[model_path] = s2id.keys()
295
  print(f"Indexed voice {voice_name}")
296
- voices.append((name, model_path, voice_name, speakerid, image))
297
  return voices, styledict
298
 
299
 
@@ -321,7 +194,22 @@ if __name__ == "__main__":
321
  parser = argparse.ArgumentParser()
322
  parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU")
323
  parser.add_argument(
324
- "--dir", "-d", type=str, help="Model directory", default=config.out_dir
325
  )
326
  args = parser.parse_args()
327
  model_dir = args.dir
@@ -383,19 +271,28 @@ if __name__ == "__main__":
383
  label="Text stylization strength",
384
  visible=True,
385
  )
386
-
387
 
388
  with gr.Blocks(theme=gr.themes.Base(primary_hue="emerald", secondary_hue="green"), title="Hololive Style-Bert-VITS2") as app:
389
  gr.Markdown(initial_md)
390
 
391
- for (name, model_path, voice_name, speakerid, image) in voicedata:
392
  with gr.TabItem(name):
393
  mn = gr.Textbox(value=model_path, visible=False, interactive=False)
394
- mp = gr.Textbox(value=f"model_assets/{model_path}/{model_path}.safetensors", visible=False, interactive=False)
395
  spk = gr.Textbox(value=speakerid, visible=False, interactive=False)
396
  with gr.Row():
397
  with gr.Column():
398
- gr.Markdown(f"**{voice_name}**\n\nModel name: {model_path}")
399
  gr.Image(f"images/{image}", label=None, show_label=False, width=300, show_download_button=False, container=False, show_share_button=False)
400
  with gr.Column():
401
  with gr.TabItem("Style using a preset"):
@@ -439,9 +336,11 @@ if __name__ == "__main__":
439
  use_style_text,
440
  style,
441
  style_weight,
 
 
442
  spk,
443
  ],
444
- outputs=[text_output, audio_output],
445
  )
446
 
447
  with gr.Row():
 
4
  import datetime
5
  import os
6
  import sys
7
+ from typing import Optional
8
  import json
9
+ import utils
10
 
11
  import gradio as gr
 
12
  import torch
13
+ import yaml
14
+
15
+ from common.constants import (
16
+ DEFAULT_ASSIST_TEXT_WEIGHT,
17
+ DEFAULT_LENGTH,
18
+ DEFAULT_LINE_SPLIT,
19
+ DEFAULT_NOISE,
20
+ DEFAULT_NOISEW,
21
+ DEFAULT_SDP_RATIO,
22
+ DEFAULT_SPLIT_INTERVAL,
23
+ DEFAULT_STYLE,
24
+ DEFAULT_STYLE_WEIGHT,
25
+ Languages,
26
+ )
27
+ from common.log import logger
28
+ from common.tts_model import ModelHolder
29
+ from infer import InvalidToneError
30
+ from text.japanese import g2kata_tone, kata_tone2phone_tone, text_normalize
31
 
32
  is_hf_spaces = os.getenv("SYSTEM") == "spaces"
33
  limit = 150
34
 
35
+ # Get path settings
36
+ with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f:
37
+ path_config: dict[str, str] = yaml.safe_load(f.read())
38
+ # dataset_root = path_config["dataset_root"]
39
+ assets_root = path_config["assets_root"]
40
 
41
  def tts_fn(
42
  model_name,
 
50
  length_scale,
51
  line_split,
52
  split_interval,
53
+ assist_text,
54
+ assist_text_weight,
55
+ use_assist_text,
56
+ style,
57
  style_weight,
58
+ kata_tone_json_str,
59
+ use_tone,
 
60
  speaker,
61
  ):
62
  if len(text)<2:
63
+ return "Please enter some text.", None, kata_tone_json_str
 
 
 
 
 
64
 
65
  if is_hf_spaces and len(text) > limit:
66
+ return f"Too long! There is a character limit of {limit} characters.", None, kata_tone_json_str
67
 
68
  if(not model_holder.current_model):
69
+ model_holder.load_model_gr(model_name, model_path)
 
70
  if(model_holder.current_model.model_path != model_path):
71
+ model_holder.load_model_gr(model_name, model_path)
 
72
  speaker_id = model_holder.current_model.spk2id[speaker]
 
73
  start_time = datetime.datetime.now()
74
 
75
+ wrong_tone_message = ""
76
+ kata_tone: Optional[list[tuple[str, int]]] = None
77
+ if use_tone and kata_tone_json_str != "":
78
+ if language != "JP":
79
+ #logger.warning("Only Japanese is supported for tone generation.")
80
+ wrong_tone_message = "アクセント指定は現在日本語のみ対応しています。"
81
+ if line_split:
82
+ #logger.warning("Tone generation is not supported for line split.")
83
+ wrong_tone_message = (
84
+ "アクセント指定は改行で分けて生成を使わない場合のみ対応しています。"
85
+ )
86
+ try:
87
+ kata_tone = []
88
+ json_data = json.loads(kata_tone_json_str)
89
+ # Convert entries to tuples
90
+ for kana, tone in json_data:
91
+ assert isinstance(kana, str) and tone in (0, 1), f"{kana}, {tone}"
92
+ kata_tone.append((kana, tone))
93
+ except Exception as e:
94
+ logger.warning(f"Error occurred when parsing kana_tone_json: {e}")
95
+ wrong_tone_message = f"アクセント指定が不正です: {e}"
96
+ kata_tone = None
97
+
98
+ # tone becomes not None only when it is actually passed to synthesis
99
+ tone: Optional[list[int]] = None
100
+ if kata_tone is not None:
101
+ phone_tone = kata_tone2phone_tone(kata_tone)
102
+ tone = [t for _, t in phone_tone]
103
+
104
+ try:
105
+ sr, audio = model_holder.current_model.infer(
106
+ text=text,
107
+ language=language,
108
+ reference_audio_path=reference_audio_path,
109
+ sdp_ratio=sdp_ratio,
110
+ noise=noise_scale,
111
+ noisew=noise_scale_w,
112
+ length=length_scale,
113
+ line_split=line_split,
114
+ split_interval=split_interval,
115
+ assist_text=assist_text,
116
+ assist_text_weight=assist_text_weight,
117
+ use_assist_text=use_assist_text,
118
+ style=style,
119
+ style_weight=style_weight,
120
+ given_tone=tone,
121
+ sid=speaker_id,
122
+ )
123
+ except InvalidToneError as e:
124
+ logger.error(f"Tone error: {e}")
125
+ return f"Error: アクセント指定が不正です:\n{e}", None, kata_tone_json_str
126
+ except ValueError as e:
127
+ logger.error(f"Value error: {e}")
128
+ return f"Error: {e}", None, kata_tone_json_str
129
 
130
  end_time = datetime.datetime.now()
131
  duration = (end_time - start_time).total_seconds()
132
+
133
+ if tone is None and language == "JP":
134
+ # Return the accent information so it can be reused for accent specification
135
+ norm_text = text_normalize(text)
136
+ kata_tone = g2kata_tone(norm_text)
137
+ kata_tone_json_str = json.dumps(kata_tone, ensure_ascii=False)
138
+ elif tone is None:
139
+ kata_tone_json_str = ""
140
+
141
+ if reference_audio_path:
142
+ style="External Audio"
143
+ logger.info(f"Successful inference, took {duration}s | {speaker} | {language}/{sdp_ratio}/{noise_scale}/{noise_scale_w}/{length_scale}/{style}/{style_weight} | {text}")
144
+ message = f"Success, time: {duration} seconds."
145
+ if wrong_tone_message != "":
146
+ message = wrong_tone_message + "\n" + message
147
+ return message, (sr, audio), kata_tone_json_str
148
 
149
  def load_voicedata():
150
  print("Loading voice data...")
 
158
  model_path = info['model_path']
159
  voice_name = info['title']
160
  speakerid = info['speakerid']
161
+ datasetauthor = info['datasetauthor']
162
  image = info['cover']
163
  if not model_path in styledict.keys():
164
  conf=f"model_assets/{model_path}/config.json"
 
166
  s2id = hps.data.style2id
167
  styledict[model_path] = s2id.keys()
168
  print(f"Indexed voice {voice_name}")
169
+ voices.append((name, model_path, voice_name, speakerid, datasetauthor, image))
170
  return voices, styledict
171
 
172
 
 
194
  parser = argparse.ArgumentParser()
195
  parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU")
196
  parser.add_argument(
197
+ "--dir", "-d", type=str, help="Model directory", default=assets_root
198
+ )
199
+ parser.add_argument(
200
+ "--share", action="store_true", help="Share this app publicly", default=False
201
+ )
202
+ parser.add_argument(
203
+ "--server-name",
204
+ type=str,
205
+ default=None,
206
+ help="Server name for Gradio app",
207
+ )
208
+ parser.add_argument(
209
+ "--no-autolaunch",
210
+ action="store_true",
211
+ default=False,
212
+ help="Do not launch app automatically",
213
  )
214
  args = parser.parse_args()
215
  model_dir = args.dir
 
271
  label="Text stylization strength",
272
  visible=True,
273
  )
 
274
 
275
  with gr.Blocks(theme=gr.themes.Base(primary_hue="emerald", secondary_hue="green"), title="Hololive Style-Bert-VITS2") as app:
276
  gr.Markdown(initial_md)
277
 
278
+ #NOT USED SINCE NONE OF MY MODELS ARE JPEXTRA.
279
+ #ONLY HERE FOR COMPATIBILITY WITH THE EXISTING INFER CODE.
280
+ #DO NOT RENDER OR MAKE VISIBLE
281
+ tone = gr.Textbox(
282
+ label="Accent adjustment (0 for low, 1 for high)",
283
+ info="This can only be used when the text is not separated by line breaks. It is not universal.",
284
+ visible=False
285
+ )
286
+ use_tone = gr.Checkbox(label="Use accent adjustment", value=False, visible=False)
287
+
288
+ for (name, model_path, voice_name, speakerid, datasetauthor, image) in voicedata:
289
  with gr.TabItem(name):
290
  mn = gr.Textbox(value=model_path, visible=False, interactive=False)
291
+ mp = gr.Textbox(value=f"model_assets\\{model_path}\\{model_path}.safetensors", visible=False, interactive=False)
292
  spk = gr.Textbox(value=speakerid, visible=False, interactive=False)
293
  with gr.Row():
294
  with gr.Column():
295
+ gr.Markdown(f"**{voice_name}**\n\nModel name: {model_path} | Dataset author: {datasetauthor}")
296
  gr.Image(f"images/{image}", label=None, show_label=False, width=300, show_download_button=False, container=False, show_share_button=False)
297
  with gr.Column():
298
  with gr.TabItem("Style using a preset"):
 
336
  use_style_text,
337
  style,
338
  style_weight,
339
+ tone,
340
+ use_tone,
341
  spk,
342
  ],
343
+ outputs=[text_output, audio_output, tone],
344
  )
345
 
346
  with gr.Row():
attentions.py CHANGED
@@ -4,7 +4,7 @@ from torch import nn
4
  from torch.nn import functional as F
5
 
6
  import commons
7
- from tools.log import logger as logging
8
 
9
 
10
  class LayerNorm(nn.Module):
 
4
  from torch.nn import functional as F
5
 
6
  import commons
7
+ from common.log import logger as logging
8
 
9
 
10
  class LayerNorm(nn.Module):
common/constants.py ADDED
@@ -0,0 +1,20 @@
1
+ import enum
2
+
3
+ DEFAULT_STYLE: str = "Neutral"
4
+ DEFAULT_STYLE_WEIGHT: float = 5.0
5
+
6
+
7
+ class Languages(str, enum.Enum):
8
+ JP = "JP"
9
+ EN = "EN"
10
+ ZH = "ZH"
11
+
12
+
13
+ DEFAULT_SDP_RATIO: float = 0.2
14
+ DEFAULT_NOISE: float = 0.6
15
+ DEFAULT_NOISEW: float = 0.8
16
+ DEFAULT_LENGTH: float = 1.0
17
+ DEFAULT_LINE_SPLIT: bool = True
18
+ DEFAULT_SPLIT_INTERVAL: float = 0.5
19
+ DEFAULT_ASSIST_TEXT_WEIGHT: float = 0.7
20
+ DEFAULT_ASSIST_TEXT_WEIGHT: float = 1.0
common/log.py ADDED
@@ -0,0 +1,16 @@
1
+ """
2
+ Logger wrapper
3
+ """
4
+ from loguru import logger
5
+
6
+ from .stdout_wrapper import SAFE_STDOUT
7
+
8
+ # Remove all default handlers
9
+ logger.remove()
10
+
11
+ # Custom format, added to standard output
12
+ log_format = (
13
+ "<g>{time:MM-DD HH:mm:ss}</g> |<lvl>{level:^8}</lvl>| {file}:{line} | {message}"
14
+ )
15
+
16
+ logger.add(SAFE_STDOUT, format=log_format, backtrace=True, diagnose=True)
common/stdout_wrapper.py ADDED
@@ -0,0 +1,34 @@
1
+ import sys
2
+ import tempfile
3
+
4
+
5
+ class StdoutWrapper:
6
+ def __init__(self):
7
+ self.temp_file = tempfile.NamedTemporaryFile(mode="w+", delete=False)
8
+ self.original_stdout = sys.stdout
9
+
10
+ def write(self, message: str):
11
+ self.temp_file.write(message)
12
+ self.temp_file.flush()
13
+ print(message, end="", file=self.original_stdout)
14
+
15
+ def flush(self):
16
+ self.temp_file.flush()
17
+
18
+ def read(self):
19
+ self.temp_file.seek(0)
20
+ return self.temp_file.read()
21
+
22
+ def close(self):
23
+ self.temp_file.close()
24
+
25
+ def fileno(self):
26
+ return self.temp_file.fileno()
27
+
28
+
29
+ try:
30
+ import google.colab
31
+
32
+ SAFE_STDOUT = StdoutWrapper()
33
+ except ImportError:
34
+ SAFE_STDOUT = sys.stdout
common/subprocess_utils.py ADDED
@@ -0,0 +1,32 @@
1
+ import subprocess
2
+ import sys
3
+
4
+ from .log import logger
5
+ from .stdout_wrapper import SAFE_STDOUT
6
+
7
+ python = sys.executable
8
+
9
+
10
+ def run_script_with_log(cmd: list[str], ignore_warning=False) -> tuple[bool, str]:
11
+ logger.info(f"Running: {' '.join(cmd)}")
12
+ result = subprocess.run(
13
+ [python] + cmd,
14
+ stdout=SAFE_STDOUT, # type: ignore
15
+ stderr=subprocess.PIPE,
16
+ text=True,
17
+ )
18
+ if result.returncode != 0:
19
+ logger.error(f"Error: {' '.join(cmd)}\n{result.stderr}")
20
+ return False, result.stderr
21
+ elif result.stderr and not ignore_warning:
22
+ logger.warning(f"Warning: {' '.join(cmd)}\n{result.stderr}")
23
+ return True, result.stderr
24
+ logger.success(f"Success: {' '.join(cmd)}")
25
+ return True, ""
26
+
27
+
28
+ def second_elem_of(original_function):
29
+ def inner_function(*args, **kwargs):
30
+ return original_function(*args, **kwargs)[1]
31
+
32
+ return inner_function
common/tts_model.py ADDED
@@ -0,0 +1,250 @@
1
+ import numpy as np
2
+ import gradio as gr
3
+ import torch
4
+ import os
5
+ import warnings
6
+ from gradio.processing_utils import convert_to_16_bit_wav
7
+ from typing import Dict, List, Optional, Union
8
+
9
+ import utils
10
+ from infer import get_net_g, infer
11
+ from models import SynthesizerTrn
12
+ from models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra
13
+
14
+ from .log import logger
15
+ from .constants import (
16
+ DEFAULT_ASSIST_TEXT_WEIGHT,
17
+ DEFAULT_LENGTH,
18
+ DEFAULT_LINE_SPLIT,
19
+ DEFAULT_NOISE,
20
+ DEFAULT_NOISEW,
21
+ DEFAULT_SDP_RATIO,
22
+ DEFAULT_SPLIT_INTERVAL,
23
+ DEFAULT_STYLE,
24
+ DEFAULT_STYLE_WEIGHT,
25
+ )
26
+
27
+
28
+ class Model:
29
+ def __init__(
30
+ self, model_path: str, config_path: str, style_vec_path: str, device: str
31
+ ):
32
+ self.model_path: str = model_path
33
+ self.config_path: str = config_path
34
+ self.device: str = device
35
+ self.style_vec_path: str = style_vec_path
36
+ self.hps: utils.HParams = utils.get_hparams_from_file(self.config_path)
37
+ self.spk2id: Dict[str, int] = self.hps.data.spk2id
38
+ self.id2spk: Dict[int, str] = {v: k for k, v in self.spk2id.items()}
39
+
40
+ self.num_styles: int = self.hps.data.num_styles
41
+ if hasattr(self.hps.data, "style2id"):
42
+ self.style2id: Dict[str, int] = self.hps.data.style2id
43
+ else:
44
+ self.style2id: Dict[str, int] = {str(i): i for i in range(self.num_styles)}
45
+ if len(self.style2id) != self.num_styles:
46
+ raise ValueError(
47
+ f"Number of styles ({self.num_styles}) does not match the number of style2id ({len(self.style2id)})"
48
+ )
49
+
50
+ self.style_vectors: np.ndarray = np.load(self.style_vec_path)
51
+ if self.style_vectors.shape[0] != self.num_styles:
52
+ raise ValueError(
53
+ f"The number of styles ({self.num_styles}) does not match the number of style vectors ({self.style_vectors.shape[0]})"
54
+ )
55
+
56
+ self.net_g: Union[SynthesizerTrn, SynthesizerTrnJPExtra, None] = None
57
+
58
+ def load_net_g(self):
59
+ self.net_g = get_net_g(
60
+ model_path=self.model_path,
61
+ version=self.hps.version,
62
+ device=self.device,
63
+ hps=self.hps,
64
+ )
65
+
66
+ def get_style_vector(self, style_id: int, weight: float = 1.0) -> np.ndarray:
67
+ mean = self.style_vectors[0]
68
+ style_vec = self.style_vectors[style_id]
69
+ style_vec = mean + (style_vec - mean) * weight
70
+ return style_vec
71
+
72
+ def get_style_vector_from_audio(
73
+ self, audio_path: str, weight: float = 1.0
74
+ ) -> np.ndarray:
75
+ from style_gen import get_style_vector
76
+
77
+ xvec = get_style_vector(audio_path)
78
+ mean = self.style_vectors[0]
79
+ xvec = mean + (xvec - mean) * weight
80
+ return xvec
81
+
82
+ def infer(
83
+ self,
84
+ text: str,
85
+ language: str = "JP",
86
+ sid: int = 0,
87
+ reference_audio_path: Optional[str] = None,
88
+ sdp_ratio: float = DEFAULT_SDP_RATIO,
89
+ noise: float = DEFAULT_NOISE,
90
+ noisew: float = DEFAULT_NOISEW,
91
+ length: float = DEFAULT_LENGTH,
92
+ line_split: bool = DEFAULT_LINE_SPLIT,
93
+ split_interval: float = DEFAULT_SPLIT_INTERVAL,
94
+ assist_text: Optional[str] = None,
95
+ assist_text_weight: float = DEFAULT_ASSIST_TEXT_WEIGHT,
96
+ use_assist_text: bool = False,
97
+ style: str = DEFAULT_STYLE,
98
+ style_weight: float = DEFAULT_STYLE_WEIGHT,
99
+ given_tone: Optional[list[int]] = None,
100
+ ) -> tuple[int, np.ndarray]:
101
+ #logger.info(f"Start generating audio data from text:\n{text}")
102
+ if language != "JP" and self.hps.version.endswith("JP-Extra"):
103
+ raise ValueError(
104
+ "The model is trained with JP-Extra, but the language is not JP"
105
+ )
106
+ if reference_audio_path == "":
107
+ reference_audio_path = None
108
+ if assist_text == "" or not use_assist_text:
109
+ assist_text = None
110
+
111
+ if self.net_g is None:
112
+ self.load_net_g()
113
+ if reference_audio_path is None:
114
+ style_id = self.style2id[style]
115
+ style_vector = self.get_style_vector(style_id, style_weight)
116
+ else:
117
+ style_vector = self.get_style_vector_from_audio(
118
+ reference_audio_path, style_weight
119
+ )
120
+ if not line_split:
121
+ with torch.no_grad():
122
+ audio = infer(
123
+ text=text,
124
+ sdp_ratio=sdp_ratio,
125
+ noise_scale=noise,
126
+ noise_scale_w=noisew,
127
+ length_scale=length,
128
+ sid=sid,
129
+ language=language,
130
+ hps=self.hps,
131
+ net_g=self.net_g,
132
+ device=self.device,
133
+ assist_text=assist_text,
134
+ assist_text_weight=assist_text_weight,
135
+ style_vec=style_vector,
136
+ given_tone=given_tone,
137
+ )
138
+ else:
139
+ texts = text.split("\n")
140
+ texts = [t for t in texts if t != ""]
141
+ audios = []
142
+ with torch.no_grad():
143
+ for i, t in enumerate(texts):
144
+ audios.append(
145
+ infer(
146
+ text=t,
147
+ sdp_ratio=sdp_ratio,
148
+ noise_scale=noise,
149
+ noise_scale_w=noisew,
150
+ length_scale=length,
151
+ sid=sid,
152
+ language=language,
153
+ hps=self.hps,
154
+ net_g=self.net_g,
155
+ device=self.device,
156
+ assist_text=assist_text,
157
+ assist_text_weight=assist_text_weight,
158
+ style_vec=style_vector,
159
+ )
160
+ )
161
+ if i != len(texts) - 1:
162
+ audios.append(np.zeros(int(44100 * split_interval)))
163
+ audio = np.concatenate(audios)
164
+ with warnings.catch_warnings():
165
+ warnings.simplefilter("ignore")
166
+ audio = convert_to_16_bit_wav(audio)
167
+ #logger.info("Audio data generated successfully")
168
+ return (self.hps.data.sampling_rate, audio)
169
+
170
+
171
+ class ModelHolder:
172
+ def __init__(self, root_dir: str, device: str):
173
+ self.root_dir: str = root_dir
174
+ self.device: str = device
175
+ self.model_files_dict: Dict[str, List[str]] = {}
176
+ self.current_model: Optional[Model] = None
177
+ self.model_names: List[str] = []
178
+ self.models: List[Model] = []
179
+ self.refresh()
180
+
181
+ def refresh(self):
182
+ self.model_files_dict = {}
183
+ self.model_names = []
184
+ self.current_model = None
185
+ model_dirs = [
186
+ d
187
+ for d in os.listdir(self.root_dir)
188
+ if os.path.isdir(os.path.join(self.root_dir, d))
189
+ ]
190
+ for model_name in model_dirs:
191
+ model_dir = os.path.join(self.root_dir, model_name)
192
+ model_files = [
193
+ os.path.join(model_dir, f)
194
+ for f in os.listdir(model_dir)
195
+ if f.endswith(".pth") or f.endswith(".pt") or f.endswith(".safetensors")
196
+ ]
197
+ if len(model_files) == 0:
198
+ logger.warning(
199
+ f"No model files found in {self.root_dir}/{model_name}, so skip it"
200
+ )
201
+ continue
202
+ self.model_files_dict[model_name] = model_files
203
+ self.model_names.append(model_name)
204
+
205
+ def load_model_gr(
206
+ self, model_name: str, model_path: str
207
+ ) -> tuple[gr.Dropdown, gr.Button, gr.Dropdown]:
208
+ if model_name not in self.model_files_dict:
209
+ raise ValueError(f"Model `{model_name}` is not found")
210
+ if model_path not in self.model_files_dict[model_name]:
211
+ raise ValueError(f"Model file `{model_path}` is not found")
212
+ if (
213
+ self.current_model is not None
214
+ and self.current_model.model_path == model_path
215
+ ):
216
+ # Already loaded
217
+ speakers = list(self.current_model.spk2id.keys())
218
+ styles = list(self.current_model.style2id.keys())
219
+ return (
220
+ gr.Dropdown(choices=styles, value=styles[0]),
221
+ gr.Button(interactive=True, value="音声合成"),
222
+ gr.Dropdown(choices=speakers, value=speakers[0]),
223
+ )
224
+ self.current_model = Model(
225
+ model_path=model_path,
226
+ config_path=os.path.join(self.root_dir, model_name, "config.json"),
227
+ style_vec_path=os.path.join(self.root_dir, model_name, "style_vectors.npy"),
228
+ device=self.device,
229
+ )
230
+ speakers = list(self.current_model.spk2id.keys())
231
+ styles = list(self.current_model.style2id.keys())
232
+ return (
233
+ gr.Dropdown(choices=styles, value=styles[0]),
234
+ gr.Button(interactive=True, value="音声合成"),
235
+ gr.Dropdown(choices=speakers, value=speakers[0]),
236
+ )
237
+
238
+ def update_model_files_gr(self, model_name: str) -> gr.Dropdown:
239
+ model_files = self.model_files_dict[model_name]
240
+ return gr.Dropdown(choices=model_files, value=model_files[0])
241
+
242
+ def update_model_names_gr(self) -> tuple[gr.Dropdown, gr.Dropdown, gr.Button]:
243
+ self.refresh()
244
+ initial_model_name = self.model_names[0]
245
+ initial_model_files = self.model_files_dict[initial_model_name]
246
+ return (
247
+ gr.Dropdown(choices=self.model_names, value=initial_model_name),
248
+ gr.Dropdown(choices=initial_model_files, value=initial_model_files[0]),
249
+ gr.Button(interactive=False), # For tts_button
250
+ )
config.py CHANGED
@@ -1,254 +1,269 @@
1
- """
2
- @Desc: 全局配置文件读取
3
- """
4
- import argparse
5
- import yaml
6
- from typing import Dict, List
7
- import os
8
- import shutil
9
- import sys
10
-
11
-
12
- class Resample_config:
13
- """重采样配置"""
14
-
15
- def __init__(self, in_dir: str, out_dir: str, sampling_rate: int = 44100):
16
- self.sampling_rate: int = sampling_rate # 目标采样率
17
- self.in_dir: str = in_dir # 待处理音频目录路径
18
- self.out_dir: str = out_dir # 重采样输出路径
19
-
20
- @classmethod
21
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
22
- """从字典中生成实例"""
23
-
24
- # 不检查路径是否有效,此逻辑在resample.py中处理
25
- data["in_dir"] = os.path.join(dataset_path, data["in_dir"])
26
- data["out_dir"] = os.path.join(dataset_path, data["out_dir"])
27
-
28
- return cls(**data)
29
-
30
-
31
- class Preprocess_text_config:
32
- """数据预处理配置"""
33
-
34
- def __init__(
35
- self,
36
- transcription_path: str,
37
- cleaned_path: str,
38
- train_path: str,
39
- val_path: str,
40
- config_path: str,
41
- val_per_lang: int = 5,
42
- max_val_total: int = 10000,
43
- clean: bool = True,
44
- ):
45
- self.transcription_path: str = transcription_path # 原始文本文件路径,文本格式应为{wav_path}|{speaker_name}|{language}|{text}。
46
- self.cleaned_path: str = cleaned_path # 数据清洗后文本路径,可以不填。不填则将在原始文本目录生成
47
- self.train_path: str = train_path # 训练集路径,可以不填。不填则将在原始文本目录生成
48
- self.val_path: str = val_path # 验证集路径,可以不填。不填则将在原始文本目录生成
49
- self.config_path: str = config_path # 配置文件路径
50
- self.val_per_lang: int = val_per_lang # 每个speaker的验证集条数
51
- self.max_val_total: int = max_val_total # 验证集最大条数,多于的会被截断并放到训练集中
52
- self.clean: bool = clean # 是否进行数据清洗
53
-
54
- @classmethod
55
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
56
- """从字典中生成实例"""
57
-
58
- data["transcription_path"] = os.path.join(
59
- dataset_path, data["transcription_path"]
60
- )
61
- if data["cleaned_path"] == "" or data["cleaned_path"] is None:
62
- data["cleaned_path"] = None
63
- else:
64
- data["cleaned_path"] = os.path.join(dataset_path, data["cleaned_path"])
65
- data["train_path"] = os.path.join(dataset_path, data["train_path"])
66
- data["val_path"] = os.path.join(dataset_path, data["val_path"])
67
- data["config_path"] = os.path.join(dataset_path, data["config_path"])
68
-
69
- return cls(**data)
70
-
71
-
72
- class Bert_gen_config:
73
- """bert_gen 配置"""
74
-
75
- def __init__(
76
- self,
77
- config_path: str,
78
- num_processes: int = 2,
79
- device: str = "cuda",
80
- use_multi_device: bool = False,
81
- ):
82
- self.config_path = config_path
83
- self.num_processes = num_processes
84
- self.device = device
85
- self.use_multi_device = use_multi_device
86
-
87
- @classmethod
88
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
89
- data["config_path"] = os.path.join(dataset_path, data["config_path"])
90
-
91
- return cls(**data)
92
-
93
-
94
- class Style_gen_config:
95
- """style_gen 配置"""
96
-
97
- def __init__(
98
- self,
99
- config_path: str,
100
- num_processes: int = 2,
101
- device: str = "cuda",
102
- ):
103
- self.config_path = config_path
104
- self.num_processes = num_processes
105
- self.device = device
106
-
107
- @classmethod
108
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
109
- data["config_path"] = os.path.join(dataset_path, data["config_path"])
110
-
111
- return cls(**data)
112
-
113
-
114
- class Train_ms_config:
115
- """训练配置"""
116
-
117
- def __init__(
118
- self,
119
- config_path: str,
120
- env: Dict[str, any],
121
- # base: Dict[str, any],
122
- model: str,
123
- num_workers: int,
124
- spec_cache: bool,
125
- keep_ckpts: int,
126
- ):
127
- self.env = env # 需要加载的环境变量
128
- # self.base = base # 底模配置
129
- self.model = model # 训练模型存储目录,该路径为相对于dataset_path的路径,而非项目根目录
130
- self.config_path = config_path # 配置文件路径
131
- self.num_workers = num_workers # worker数量
132
- self.spec_cache = spec_cache # 是否启用spec缓存
133
- self.keep_ckpts = keep_ckpts # ckpt数量
134
-
135
- @classmethod
136
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
137
- # data["model"] = os.path.join(dataset_path, data["model"])
138
- data["config_path"] = os.path.join(dataset_path, data["config_path"])
139
-
140
- return cls(**data)
141
-
142
-
143
- class Webui_config:
144
- """webui 配置"""
145
-
146
- def __init__(
147
- self,
148
- device: str,
149
- model: str,
150
- config_path: str,
151
- language_identification_library: str,
152
- port: int = 7860,
153
- share: bool = False,
154
- debug: bool = False,
155
- ):
156
- self.device: str = device
157
- self.model: str = model # 端口号
158
- self.config_path: str = config_path # 是否公开部署,对外网开放
159
- self.port: int = port # 是否开启debug模式
160
- self.share: bool = share # 模型路径
161
- self.debug: bool = debug # 配置文件路径
162
- self.language_identification_library: str = (
163
- language_identification_library # 语种识别库
164
- )
165
-
166
- @classmethod
167
- def from_dict(cls, dataset_path: str, data: Dict[str, any]):
168
- data["config_path"] = os.path.join(dataset_path, data["config_path"])
169
- data["model"] = os.path.join(dataset_path, data["model"])
170
- return cls(**data)
171
-
172
-
173
- class Server_config:
174
- def __init__(
175
- self, models: List[Dict[str, any]], port: int = 5000, device: str = "cuda"
176
- ):
177
- self.models: List[Dict[str, any]] = models # 需要加载的所有模型的配置
178
- self.port: int = port # 端口号
179
- self.device: str = device # 模型默认使用设备
180
-
181
- @classmethod
182
- def from_dict(cls, data: Dict[str, any]):
183
- return cls(**data)
184
-
185
-
186
- class Translate_config:
187
- """翻译api配置"""
188
-
189
- def __init__(self, app_key: str, secret_key: str):
190
- self.app_key = app_key
191
- self.secret_key = secret_key
192
-
193
- @classmethod
194
- def from_dict(cls, data: Dict[str, any]):
195
- return cls(**data)
196
-
197
-
198
- class Config:
199
- def __init__(self, config_path: str):
200
- if not os.path.isfile(config_path) and os.path.isfile("default_config.yml"):
201
- shutil.copy(src="default_config.yml", dst=config_path)
202
- print(
203
- f"A configuration file {config_path} has been generated based on the default configuration file default_config.yml."
204
- )
205
- print(
206
- "If you have no special needs, please do not modify default_config.yml."
207
- )
208
- # sys.exit(0)
209
- with open(file=config_path, mode="r", encoding="utf-8") as file:
210
- yaml_config: Dict[str, any] = yaml.safe_load(file.read())
211
- model_name: str = yaml_config["model_name"]
212
- self.model_name: str = model_name
213
- if "dataset_path" in yaml_config:
214
- dataset_path = yaml_config["dataset_path"]
215
- else:
216
- dataset_path = f"Data/{model_name}"
217
- self.out_dir = yaml_config["out_dir"]
218
- # openi_token: str = yaml_config["openi_token"]
219
- self.dataset_path: str = dataset_path
220
- # self.mirror: str = yaml_config["mirror"]
221
- # self.openi_token: str = openi_token
222
- self.resample_config: Resample_config = Resample_config.from_dict(
223
- dataset_path, yaml_config["resample"]
224
- )
225
- self.preprocess_text_config: Preprocess_text_config = (
226
- Preprocess_text_config.from_dict(
227
- dataset_path, yaml_config["preprocess_text"]
228
- )
229
- )
230
- self.bert_gen_config: Bert_gen_config = Bert_gen_config.from_dict(
231
- dataset_path, yaml_config["bert_gen"]
232
- )
233
- self.style_gen_config: Style_gen_config = Style_gen_config.from_dict(
234
- dataset_path, yaml_config["style_gen"]
235
- )
236
- self.train_ms_config: Train_ms_config = Train_ms_config.from_dict(
237
- dataset_path, yaml_config["train_ms"]
238
- )
239
- self.webui_config: Webui_config = Webui_config.from_dict(
240
- dataset_path, yaml_config["webui"]
241
- )
242
- self.server_config: Server_config = Server_config.from_dict(
243
- yaml_config["server"]
244
- )
245
- # self.translate_config: Translate_config = Translate_config.from_dict(
246
- # yaml_config["translate"]
247
- # )
248
-
249
-
250
- parser = argparse.ArgumentParser()
251
- # 为避免与以前的config.json起冲突,将其更名如下
252
- parser.add_argument("-y", "--yml_config", type=str, default="config.yml")
253
- args, _ = parser.parse_known_args()
254
- config = Config(args.yml_config)
1
+ """
2
+ @Desc: 全局配置文件读取
3
+ """
4
+ import argparse
5
+ import os
6
+ import shutil
7
+ from typing import Dict, List
8
+
9
+ import yaml
10
+
11
+ from common.log import logger
12
+
13
+
14
+ class Resample_config:
15
+ """重采样配置"""
16
+
17
+ def __init__(self, in_dir: str, out_dir: str, sampling_rate: int = 44100):
18
+ self.sampling_rate: int = sampling_rate # 目标采样率
19
+ self.in_dir: str = in_dir # 待处理音频目录路径
20
+ self.out_dir: str = out_dir # 重采样输出路径
21
+
22
+ @classmethod
23
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
24
+ """从字典中生成实例"""
25
+
26
+ # 不检查路径是否有效,此逻辑在resample.py中处理
27
+ data["in_dir"] = os.path.join(dataset_path, data["in_dir"])
28
+ data["out_dir"] = os.path.join(dataset_path, data["out_dir"])
29
+
30
+ return cls(**data)
31
+
32
+
33
+ class Preprocess_text_config:
34
+ """数据预处理配置"""
35
+
36
+ def __init__(
37
+ self,
38
+ transcription_path: str,
39
+ cleaned_path: str,
40
+ train_path: str,
41
+ val_path: str,
42
+ config_path: str,
43
+ val_per_lang: int = 5,
44
+ max_val_total: int = 10000,
45
+ clean: bool = True,
46
+ ):
47
+ self.transcription_path: str = transcription_path # 原始文本文件路径,文本格式应为{wav_path}|{speaker_name}|{language}|{text}。
48
+ self.cleaned_path: str = cleaned_path # 数据清洗后文本路径,可以不填。不填则将在原始文本目录生成
49
+ self.train_path: str = train_path # 训练集路径,可以不填。不填则将在原始文本目录生成
50
+ self.val_path: str = val_path # 验证集路径,可以不填。不填则将在原始文本目录生成
51
+ self.config_path: str = config_path # 配置文件路径
52
+ self.val_per_lang: int = val_per_lang # 每个speaker的验证集条数
53
+ self.max_val_total: int = max_val_total # 验证集最大条数,多于的会被截断并放到训练集中
54
+ self.clean: bool = clean # 是否进行数据清洗
55
+
56
+ @classmethod
57
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
58
+ """从字典中生成实例"""
59
+
60
+ data["transcription_path"] = os.path.join(
61
+ dataset_path, data["transcription_path"]
62
+ )
63
+ if data["cleaned_path"] == "" or data["cleaned_path"] is None:
64
+ data["cleaned_path"] = None
65
+ else:
66
+ data["cleaned_path"] = os.path.join(dataset_path, data["cleaned_path"])
67
+ data["train_path"] = os.path.join(dataset_path, data["train_path"])
68
+ data["val_path"] = os.path.join(dataset_path, data["val_path"])
69
+ data["config_path"] = os.path.join(dataset_path, data["config_path"])
70
+
71
+ return cls(**data)
72
+
73
+
74
+ class Bert_gen_config:
75
+ """bert_gen 配置"""
76
+
77
+ def __init__(
78
+ self,
79
+ config_path: str,
80
+ num_processes: int = 2,
81
+ device: str = "cuda",
82
+ use_multi_device: bool = False,
83
+ ):
84
+ self.config_path = config_path
85
+ self.num_processes = num_processes
86
+ self.device = device
87
+ self.use_multi_device = use_multi_device
88
+
89
+ @classmethod
90
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
91
+ data["config_path"] = os.path.join(dataset_path, data["config_path"])
92
+
93
+ return cls(**data)
94
+
95
+
96
+ class Style_gen_config:
97
+ """style_gen 配置"""
98
+
99
+ def __init__(
100
+ self,
101
+ config_path: str,
102
+ num_processes: int = 4,
103
+ device: str = "cuda",
104
+ ):
105
+ self.config_path = config_path
106
+ self.num_processes = num_processes
107
+ self.device = device
108
+
109
+ @classmethod
110
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
111
+ data["config_path"] = os.path.join(dataset_path, data["config_path"])
112
+
113
+ return cls(**data)
114
+
115
+
116
+ class Train_ms_config:
117
+ """训练配置"""
118
+
119
+ def __init__(
120
+ self,
121
+ config_path: str,
122
+ env: Dict[str, any],
123
+ # base: Dict[str, any],
124
+ model_dir: str,
125
+ num_workers: int,
126
+ spec_cache: bool,
127
+ keep_ckpts: int,
128
+ ):
129
+ self.env = env # 需要加载的环境变量
130
+ # self.base = base # 底模配置
131
+ self.model_dir = model_dir # 训练模型存储目录,该路径为相对于dataset_path的路径,而非项目根目录
132
+ self.config_path = config_path # 配置文件路径
133
+ self.num_workers = num_workers # worker数量
134
+ self.spec_cache = spec_cache # 是否启用spec缓存
135
+ self.keep_ckpts = keep_ckpts # ckpt数量
136
+
137
+ @classmethod
138
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
139
+ # data["model"] = os.path.join(dataset_path, data["model"])
140
+ data["config_path"] = os.path.join(dataset_path, data["config_path"])
141
+
142
+ return cls(**data)
143
+
144
+
145
+ class Webui_config:
146
+ """webui 配置"""
147
+
148
+ def __init__(
149
+ self,
150
+ device: str,
151
+ model: str,
152
+ config_path: str,
153
+ language_identification_library: str,
154
+ port: int = 7860,
155
+ share: bool = False,
156
+ debug: bool = False,
157
+ ):
158
+ self.device: str = device
159
+ self.model: str = model # 端口号
160
+ self.config_path: str = config_path # 是否公开部署,对外网开放
161
+ self.port: int = port # 是否开启debug模式
162
+ self.share: bool = share # 模型路径
163
+ self.debug: bool = debug # 配置文件路径
164
+ self.language_identification_library: str = (
165
+ language_identification_library # 语种识别库
166
+ )
167
+
168
+ @classmethod
169
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
170
+ data["config_path"] = os.path.join(dataset_path, data["config_path"])
171
+ data["model"] = os.path.join(dataset_path, data["model"])
172
+ return cls(**data)
173
+
174
+
175
+ class Server_config:
176
+ def __init__(
177
+ self,
178
+ port: int = 5000,
179
+ device: str = "cuda",
180
+ limit: int = 100,
181
+ language: str = "JP",
182
+ origins: List[str] = None,
183
+ ):
184
+ self.port: int = port
185
+ self.device: str = device
186
+ self.language: str = language
187
+ self.limit: int = limit
188
+ self.origins: List[str] = origins
189
+
190
+ @classmethod
191
+ def from_dict(cls, data: Dict[str, any]):
192
+ return cls(**data)
193
+
194
+
195
+ class Translate_config:
196
+ """翻译api配置"""
197
+
198
+ def __init__(self, app_key: str, secret_key: str):
199
+ self.app_key = app_key
200
+ self.secret_key = secret_key
201
+
202
+ @classmethod
203
+ def from_dict(cls, data: Dict[str, any]):
204
+ return cls(**data)
205
+
206
+
207
+ class Config:
208
+ def __init__(self, config_path: str, path_config: dict[str, str]):
209
+ if not os.path.isfile(config_path) and os.path.isfile("default_config.yml"):
210
+ shutil.copy(src="default_config.yml", dst=config_path)
211
+ logger.info(
212
+ f"A configuration file {config_path} has been generated based on the default configuration file default_config.yml."
213
+ )
214
+ logger.info(
215
+ "If you have no special needs, please do not modify default_config.yml."
216
+ )
217
+ # sys.exit(0)
218
+ with open(file=config_path, mode="r", encoding="utf-8") as file:
219
+ yaml_config: Dict[str, any] = yaml.safe_load(file.read())
220
+ model_name: str = yaml_config["model_name"]
221
+ self.model_name: str = model_name
222
+ if "dataset_path" in yaml_config:
223
+ dataset_path = yaml_config["dataset_path"]
224
+ else:
225
+ dataset_path = os.path.join(path_config["dataset_root"], model_name)
226
+ self.dataset_path: str = dataset_path
227
+ self.assets_root: str = path_config["assets_root"]
228
+ self.out_dir = os.path.join(self.assets_root, model_name)
229
+ self.resample_config: Resample_config = Resample_config.from_dict(
230
+ dataset_path, yaml_config["resample"]
231
+ )
232
+ self.preprocess_text_config: Preprocess_text_config = (
233
+ Preprocess_text_config.from_dict(
234
+ dataset_path, yaml_config["preprocess_text"]
235
+ )
236
+ )
237
+ self.bert_gen_config: Bert_gen_config = Bert_gen_config.from_dict(
238
+ dataset_path, yaml_config["bert_gen"]
239
+ )
240
+ self.style_gen_config: Style_gen_config = Style_gen_config.from_dict(
241
+ dataset_path, yaml_config["style_gen"]
242
+ )
243
+ self.train_ms_config: Train_ms_config = Train_ms_config.from_dict(
244
+ dataset_path, yaml_config["train_ms"]
245
+ )
246
+ self.webui_config: Webui_config = Webui_config.from_dict(
247
+ dataset_path, yaml_config["webui"]
248
+ )
249
+ self.server_config: Server_config = Server_config.from_dict(
250
+ yaml_config["server"]
251
+ )
252
+ # self.translate_config: Translate_config = Translate_config.from_dict(
253
+ # yaml_config["translate"]
254
+ # )
255
+
256
+
257
+ with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f:
258
+ path_config: dict[str, str] = yaml.safe_load(f.read())
259
+ # Should contain the following keys:
260
+ # - dataset_root: the root directory of the dataset, default to "Data"
261
+ # - assets_root: the root directory of the assets, default to "model_assets"
262
+
263
+
264
+ try:
265
+ config = Config("config.yml", path_config)
266
+ except (TypeError, KeyError):
267
+ logger.warning("Old config.yml found. Replace it with default_config.yml.")
268
+ shutil.copy(src="default_config.yml", dst="config.yml")
269
+ config = Config("config.yml", path_config)
config.yml CHANGED
@@ -1,36 +1,29 @@
1
  bert_gen:
2
  config_path: config.json
3
  device: cpu
4
- num_processes: 4
5
  use_multi_device: false
6
- dataset_path: Data\jvnv-M2
7
- model_name: jvnv-M2
8
- out_dir: model_assets
9
  preprocess_text:
10
  clean: true
11
  cleaned_path: ''
12
  config_path: config.json
13
  max_val_total: 12
14
- train_path: filelists/train.list
15
- transcription_path: filelists/text.list
16
- val_path: filelists/val.list
17
  val_per_lang: 4
18
  resample:
19
- in_dir: audios/raw
20
- out_dir: audios/wavs
21
  sampling_rate: 44100
22
  server:
23
- device: cpu
24
- models:
25
- - config: ''
26
- device: cpu
27
- language: ZH
28
- model: ''
29
- - config: ''
30
- device: cpu
31
- language: JP
32
- model: ''
33
- speakers: []
34
  port: 5000
35
  style_gen:
36
  config_path: config.json
@@ -45,13 +38,13 @@ train_ms:
45
  RANK: 0
46
  WORLD_SIZE: 1
47
  keep_ckpts: 1
48
- model: models
49
  num_workers: 16
50
  spec_cache: true
51
  webui:
52
  config_path: config.json
53
  debug: false
54
- device: cpu
55
  language_identification_library: langid
56
  model: models/G_8000.pth
57
  port: 7860
 
1
  bert_gen:
2
  config_path: config.json
3
  device: cpu
4
+ num_processes: 2
5
  use_multi_device: false
6
+ dataset_path: Data\model_name
7
+ model_name: model_name
 
8
  preprocess_text:
9
  clean: true
10
  cleaned_path: ''
11
  config_path: config.json
12
  max_val_total: 12
13
+ train_path: train.list
14
+ transcription_path: esd.list
15
+ val_path: val.list
16
  val_per_lang: 4
17
  resample:
18
+ in_dir: raw
19
+ out_dir: wavs
20
  sampling_rate: 44100
21
  server:
22
+ device: cuda
23
+ language: JP
24
+ limit: 100
25
+ origins:
26
+ - '*'
27
  port: 5000
28
  style_gen:
29
  config_path: config.json
 
38
  RANK: 0
39
  WORLD_SIZE: 1
40
  keep_ckpts: 1
41
+ model_dir: models
42
  num_workers: 16
43
  spec_cache: true
44
  webui:
45
  config_path: config.json
46
  debug: false
47
+ device: cuda
48
  language_identification_library: langid
49
  model: models/G_8000.pth
50
  port: 7860
configs/config.json ADDED
@@ -0,0 +1,71 @@
1
+ {
2
+ "model_name": "your_model_name",
3
+ "train": {
4
+ "log_interval": 200,
5
+ "eval_interval": 1000,
6
+ "seed": 42,
7
+ "epochs": 1000,
8
+ "learning_rate": 0.0002,
9
+ "betas": [0.8, 0.99],
10
+ "eps": 1e-9,
11
+ "batch_size": 4,
12
+ "bf16_run": true,
13
+ "lr_decay": 0.99995,
14
+ "segment_size": 16384,
15
+ "init_lr_ratio": 1,
16
+ "warmup_epochs": 0,
17
+ "c_mel": 45,
18
+ "c_kl": 1.0,
19
+ "skip_optimizer": false,
20
+ "freeze_ZH_bert": false,
21
+ "freeze_JP_bert": false,
22
+ "freeze_EN_bert": false,
23
+ "freeze_style": false
24
+ },
25
+ "data": {
26
+ "training_files": "Data/your_model_name/filelists/train.list",
27
+ "validation_files": "Data/your_model_name/filelists/val.list",
28
+ "max_wav_value": 32768.0,
29
+ "sampling_rate": 44100,
30
+ "filter_length": 2048,
31
+ "hop_length": 512,
32
+ "win_length": 2048,
33
+ "n_mel_channels": 128,
34
+ "mel_fmin": 0.0,
35
+ "mel_fmax": null,
36
+ "add_blank": true,
37
+ "n_speakers": 1,
38
+ "cleaned_text": true,
39
+ "num_styles": 1,
40
+ "style2id": {
41
+ "Neutral": 0
42
+ }
43
+ },
44
+ "model": {
45
+ "use_spk_conditioned_encoder": true,
46
+ "use_noise_scaled_mas": true,
47
+ "use_mel_posterior_encoder": false,
48
+ "use_duration_discriminator": true,
49
+ "inter_channels": 192,
50
+ "hidden_channels": 192,
51
+ "filter_channels": 768,
52
+ "n_heads": 2,
53
+ "n_layers": 6,
54
+ "kernel_size": 3,
55
+ "p_dropout": 0.1,
56
+ "resblock": "1",
57
+ "resblock_kernel_sizes": [3, 7, 11],
58
+ "resblock_dilation_sizes": [
59
+ [1, 3, 5],
60
+ [1, 3, 5],
61
+ [1, 3, 5]
62
+ ],
63
+ "upsample_rates": [8, 8, 2, 2, 2],
64
+ "upsample_initial_channel": 512,
65
+ "upsample_kernel_sizes": [16, 16, 8, 2, 2],
66
+ "n_layers_q": 3,
67
+ "use_spectral_norm": false,
68
+ "gin_channels": 256
69
+ },
70
+ "version": "2.0.1"
71
+ }
configs/configs_jp_extra.json ADDED
@@ -0,0 +1,78 @@
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 1000,
5
+ "seed": 42,
6
+ "epochs": 1000,
7
+ "learning_rate": 0.0001,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 24,
11
+ "bf16_run": false,
12
+ "fp16_run": false,
13
+ "lr_decay": 0.99996,
14
+ "segment_size": 16384,
15
+ "init_lr_ratio": 1,
16
+ "warmup_epochs": 0,
17
+ "c_mel": 45,
18
+ "c_kl": 1.0,
19
+ "c_commit": 100,
20
+ "skip_optimizer": true,
21
+ "freeze_ZH_bert": false,
22
+ "freeze_JP_bert": false,
23
+ "freeze_EN_bert": false,
24
+ "freeze_emo": false,
25
+ "freeze_style": false
26
+ },
27
+ "data": {
28
+ "use_jp_extra": true,
29
+ "training_files": "filelists/train.list",
30
+ "validation_files": "filelists/val.list",
31
+ "max_wav_value": 32768.0,
32
+ "sampling_rate": 44100,
33
+ "filter_length": 2048,
34
+ "hop_length": 512,
35
+ "win_length": 2048,
36
+ "n_mel_channels": 128,
37
+ "mel_fmin": 0.0,
38
+ "mel_fmax": null,
39
+ "add_blank": true,
40
+ "n_speakers": 512,
41
+ "cleaned_text": true
42
+ },
43
+ "model": {
44
+ "use_spk_conditioned_encoder": true,
45
+ "use_noise_scaled_mas": true,
46
+ "use_mel_posterior_encoder": false,
47
+ "use_duration_discriminator": false,
48
+ "use_wavlm_discriminator": true,
49
+ "inter_channels": 192,
50
+ "hidden_channels": 192,
51
+ "filter_channels": 768,
52
+ "n_heads": 2,
53
+ "n_layers": 6,
54
+ "kernel_size": 3,
55
+ "p_dropout": 0.1,
56
+ "resblock": "1",
57
+ "resblock_kernel_sizes": [3, 7, 11],
58
+ "resblock_dilation_sizes": [
59
+ [1, 3, 5],
60
+ [1, 3, 5],
61
+ [1, 3, 5]
62
+ ],
63
+ "upsample_rates": [8, 8, 2, 2, 2],
64
+ "upsample_initial_channel": 512,
65
+ "upsample_kernel_sizes": [16, 16, 8, 2, 2],
66
+ "n_layers_q": 3,
67
+ "use_spectral_norm": false,
68
+ "gin_channels": 512,
69
+ "slm": {
70
+ "model": "./slm/wavlm-base-plus",
71
+ "sr": 16000,
72
+ "hidden": 768,
73
+ "nlayers": 13,
74
+ "initial_channel": 64
75
+ }
76
+ },
77
+ "version": "2.0.1-JP-Extra"
78
+ }
configs/paths.yml ADDED
@@ -0,0 +1,8 @@
1
+ # Root directory of the training dataset.
2
+ # The training dataset of {model_name} should be placed in {dataset_root}/{model_name}.
3
+ dataset_root: Data
4
+
5
+ # Root directory of the model assets (for inference).
6
+ # In training, the model assets will be saved to {assets_root}/{model_name},
7
+ # and in inference, we load all the models from {assets_root}.
8
+ assets_root: model_assets
default_config.yml CHANGED
@@ -1,81 +1,81 @@
1
- # Global configuration file for Bert-VITS2
2
-
3
- model_name: "model_name"
4
-
5
- out_dir: "model_assets"
6
-
7
- # If you want to use a specific dataset path, uncomment the following line.
8
- # Otherwise, the dataset path is `Data/{model_name}`.
9
-
10
- # dataset_path: "your/dataset/path"
11
-
12
- resample:
13
- sampling_rate: 44100
14
- in_dir: "audios/raw"
15
- out_dir: "audios/wavs"
16
-
17
- preprocess_text:
18
- transcription_path: "filelists/esd.list"
19
- cleaned_path: ""
20
- train_path: "filelists/train.list"
21
- val_path: "filelists/val.list"
22
- config_path: "config.json"
23
- val_per_lang: 4
24
- max_val_total: 12
25
- clean: true
26
-
27
- bert_gen:
28
- config_path: "config.json"
29
- num_processes: 4
30
- device: "cuda"
31
- use_multi_device: false
32
-
33
- style_gen:
34
- config_path: "config.json"
35
- num_processes: 4
36
- device: "cuda"
37
-
38
- train_ms:
39
- env:
40
- MASTER_ADDR: "localhost"
41
- MASTER_PORT: 10086
42
- WORLD_SIZE: 1
43
- LOCAL_RANK: 0
44
- RANK: 0
45
- model: "models"
46
- config_path: "config.json"
47
- num_workers: 16
48
- spec_cache: True
49
- keep_ckpts: 1 # Set this to 0 to keep all checkpoints
50
-
51
- webui:
52
- # 推理设备
53
- device: "cuda"
54
- # 模型路径
55
- model: "models/G_8000.pth"
56
- # 配置文件路径
57
- config_path: "config.json"
58
- # 端口号
59
- port: 7860
60
- # 是否公开部署,对外网开放
61
- share: false
62
- # 是否开启debug模式
63
- debug: false
64
- # 语种识别库,可选langid, fastlid
65
- language_identification_library: "langid"
66
-
67
- # server_fastapi's config
68
- # TODO: `server_fastapi.py` is not implemented yet for this version
69
- server:
70
- port: 5000
71
- device: "cuda"
72
- models:
73
- - model: ""
74
- config: ""
75
- device: "cuda"
76
- language: "ZH"
77
- - model: ""
78
- config: ""
79
- device: "cpu"
80
- language: "JP"
81
- speakers: []
 
1
+ # Global configuration file for Bert-VITS2
2
+
3
+ model_name: "model_name"
4
+
5
+ out_dir: "model_assets"
6
+
7
+ # If you want to use a specific dataset path, uncomment the following line.
8
+ # Otherwise, the dataset path is `Data/{model_name}`.
9
+
10
+ # dataset_path: "your/dataset/path"
11
+
12
+ resample:
13
+ sampling_rate: 44100
14
+ in_dir: "audios/raw"
15
+ out_dir: "audios/wavs"
16
+
17
+ preprocess_text:
18
+ transcription_path: "filelists/esd.list"
19
+ cleaned_path: ""
20
+ train_path: "filelists/train.list"
21
+ val_path: "filelists/val.list"
22
+ config_path: "config.json"
23
+ val_per_lang: 4
24
+ max_val_total: 12
25
+ clean: true
26
+
27
+ bert_gen:
28
+ config_path: "config.json"
29
+ num_processes: 4
30
+ device: "cuda"
31
+ use_multi_device: false
32
+
33
+ style_gen:
34
+ config_path: "config.json"
35
+ num_processes: 4
36
+ device: "cuda"
37
+
38
+ train_ms:
39
+ env:
40
+ MASTER_ADDR: "localhost"
41
+ MASTER_PORT: 10086
42
+ WORLD_SIZE: 1
43
+ LOCAL_RANK: 0
44
+ RANK: 0
45
+ model: "models"
46
+ config_path: "config.json"
47
+ num_workers: 16
48
+ spec_cache: True
49
+ keep_ckpts: 1 # Set this to 0 to keep all checkpoints
50
+
51
+ webui:
52
+ # Inference device
53
+ device: "cuda"
54
+ # Model path
55
+ model: "models/G_8000.pth"
56
+ # Config file path
57
+ config_path: "config.json"
58
+ # Port number
59
+ port: 7860
60
+ # Whether to deploy publicly (expose to the external network)
61
+ share: false
62
+ # Whether to enable debug mode
63
+ debug: false
64
+ # Language identification library; options: langid, fastlid
65
+ language_identification_library: "langid"
66
+
67
+ # server_fastapi's config
68
+ # TODO: `server_fastapi.py` is not implemented yet for this version
69
+ server:
70
+ port: 5000
71
+ device: "cuda"
72
+ models:
73
+ - model: ""
74
+ config: ""
75
+ device: "cuda"
76
+ language: "ZH"
77
+ - model: ""
78
+ config: ""
79
+ device: "cpu"
80
+ language: "JP"
81
+ speakers: []
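When `dataset_path` is left commented out, the dataset location falls back to `Data/{model_name}`, as the comment at the top of the file states. A rough sketch of that fallback, assuming the file is read with PyYAML (not the project's actual loader):

import yaml

with open("default_config.yml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

model_name = cfg["model_name"]
dataset_path = cfg.get("dataset_path", f"Data/{model_name}")  # fallback when the key is absent
print("dataset_path:", dataset_path)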
images/flare.png ADDED
images/laplus.png ADDED
images/marine.png ADDED
images/mel.png ADDED
images/noel.png ADDED
images/okayu.png ADDED
images/ririka.png ADDED
infer.py CHANGED
@@ -1,263 +1,306 @@
1
- import torch
2
-
3
- import commons
4
- import utils
5
- from models import SynthesizerTrn
6
- from text import cleaned_text_to_sequence, get_bert
7
- from text.cleaner import clean_text
8
- from text.symbols import symbols
9
-
10
- # latest_version = "1.0"
11
-
12
-
13
- def get_net_g(model_path: str, version: str, device: str, hps):
14
- net_g = SynthesizerTrn(
15
- len(symbols),
16
- hps.data.filter_length // 2 + 1,
17
- hps.train.segment_size // hps.data.hop_length,
18
- n_speakers=hps.data.n_speakers,
19
- **hps.model,
20
- ).to(device)
21
- net_g.state_dict()
22
- _ = net_g.eval()
23
- if model_path.endswith(".pth") or model_path.endswith(".pt"):
24
- _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
25
- elif model_path.endswith(".safetensors"):
26
- _ = utils.load_safetensors(model_path, net_g, device)
27
- else:
28
- raise ValueError(f"Unknown model format: {model_path}")
29
- return net_g
30
-
31
-
32
- def get_text(text, language_str, hps, device, style_text=None, style_weight=0.7):
33
- # Implement the current version of get_text here
34
- norm_text, phone, tone, word2ph = clean_text(text, language_str)
35
- phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
36
-
37
- if hps.data.add_blank:
38
- phone = commons.intersperse(phone, 0)
39
- tone = commons.intersperse(tone, 0)
40
- language = commons.intersperse(language, 0)
41
- for i in range(len(word2ph)):
42
- word2ph[i] = word2ph[i] * 2
43
- word2ph[0] += 1
44
- bert_ori = get_bert(
45
- norm_text, word2ph, language_str, device, style_text, style_weight
46
- )
47
- del word2ph
48
- assert bert_ori.shape[-1] == len(phone), phone
49
-
50
- if language_str == "ZH":
51
- bert = bert_ori
52
- ja_bert = torch.zeros(1024, len(phone))
53
- en_bert = torch.zeros(1024, len(phone))
54
- elif language_str == "JP":
55
- bert = torch.zeros(1024, len(phone))
56
- ja_bert = bert_ori
57
- en_bert = torch.zeros(1024, len(phone))
58
- elif language_str == "EN":
59
- bert = torch.zeros(1024, len(phone))
60
- ja_bert = torch.zeros(1024, len(phone))
61
- en_bert = bert_ori
62
- else:
63
- raise ValueError("language_str should be ZH, JP or EN")
64
-
65
- assert bert.shape[-1] == len(
66
- phone
67
- ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
68
-
69
- phone = torch.LongTensor(phone)
70
- tone = torch.LongTensor(tone)
71
- language = torch.LongTensor(language)
72
- return bert, ja_bert, en_bert, phone, tone, language
73
-
74
-
75
- def infer(
76
- text,
77
- style_vec,
78
- sdp_ratio,
79
- noise_scale,
80
- noise_scale_w,
81
- length_scale,
82
- sid: int, # In the original Bert-VITS2, its speaker_name: str, but here it's id
83
- language,
84
- hps,
85
- net_g,
86
- device,
87
- skip_start=False,
88
- skip_end=False,
89
- style_text=None,
90
- style_weight=0.7,
91
- ):
92
- bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
93
- text,
94
- language,
95
- hps,
96
- device,
97
- style_text=style_text,
98
- style_weight=style_weight,
99
- )
100
- if skip_start:
101
- phones = phones[3:]
102
- tones = tones[3:]
103
- lang_ids = lang_ids[3:]
104
- bert = bert[:, 3:]
105
- ja_bert = ja_bert[:, 3:]
106
- en_bert = en_bert[:, 3:]
107
- if skip_end:
108
- phones = phones[:-2]
109
- tones = tones[:-2]
110
- lang_ids = lang_ids[:-2]
111
- bert = bert[:, :-2]
112
- ja_bert = ja_bert[:, :-2]
113
- en_bert = en_bert[:, :-2]
114
- with torch.no_grad():
115
- x_tst = phones.to(device).unsqueeze(0)
116
- tones = tones.to(device).unsqueeze(0)
117
- lang_ids = lang_ids.to(device).unsqueeze(0)
118
- bert = bert.to(device).unsqueeze(0)
119
- ja_bert = ja_bert.to(device).unsqueeze(0)
120
- en_bert = en_bert.to(device).unsqueeze(0)
121
- x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
122
- style_vec = torch.from_numpy(style_vec).to(device).unsqueeze(0)
123
- del phones
124
- sid_tensor = torch.LongTensor([sid]).to(device)
125
- audio = (
126
- net_g.infer(
127
- x_tst,
128
- x_tst_lengths,
129
- sid_tensor,
130
- tones,
131
- lang_ids,
132
- bert,
133
- ja_bert,
134
- en_bert,
135
- style_vec=style_vec,
136
- sdp_ratio=sdp_ratio,
137
- noise_scale=noise_scale,
138
- noise_scale_w=noise_scale_w,
139
- length_scale=length_scale,
140
- )[0][0, 0]
141
- .data.cpu()
142
- .float()
143
- .numpy()
144
- )
145
- del (
146
- x_tst,
147
- tones,
148
- lang_ids,
149
- bert,
150
- x_tst_lengths,
151
- sid_tensor,
152
- ja_bert,
153
- en_bert,
154
- style_vec,
155
- ) # , emo
156
- if torch.cuda.is_available():
157
- torch.cuda.empty_cache()
158
- return audio
159
-
160
-
161
- def infer_multilang(
162
- text,
163
- style_vec,
164
- sdp_ratio,
165
- noise_scale,
166
- noise_scale_w,
167
- length_scale,
168
- sid,
169
- language,
170
- hps,
171
- net_g,
172
- device,
173
- skip_start=False,
174
- skip_end=False,
175
- ):
176
- bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
177
- # emo = get_emo_(reference_audio, emotion, sid)
178
- # if isinstance(reference_audio, np.ndarray):
179
- # emo = get_clap_audio_feature(reference_audio, device)
180
- # else:
181
- # emo = get_clap_text_feature(emotion, device)
182
- # emo = torch.squeeze(emo, dim=1)
183
- for idx, (txt, lang) in enumerate(zip(text, language)):
184
- _skip_start = (idx != 0) or (skip_start and idx == 0)
185
- _skip_end = (idx != len(language) - 1) or skip_end
186
- (
187
- temp_bert,
188
- temp_ja_bert,
189
- temp_en_bert,
190
- temp_phones,
191
- temp_tones,
192
- temp_lang_ids,
193
- ) = get_text(txt, lang, hps, device)
194
- if _skip_start:
195
- temp_bert = temp_bert[:, 3:]
196
- temp_ja_bert = temp_ja_bert[:, 3:]
197
- temp_en_bert = temp_en_bert[:, 3:]
198
- temp_phones = temp_phones[3:]
199
- temp_tones = temp_tones[3:]
200
- temp_lang_ids = temp_lang_ids[3:]
201
- if _skip_end:
202
- temp_bert = temp_bert[:, :-2]
203
- temp_ja_bert = temp_ja_bert[:, :-2]
204
- temp_en_bert = temp_en_bert[:, :-2]
205
- temp_phones = temp_phones[:-2]
206
- temp_tones = temp_tones[:-2]
207
- temp_lang_ids = temp_lang_ids[:-2]
208
- bert.append(temp_bert)
209
- ja_bert.append(temp_ja_bert)
210
- en_bert.append(temp_en_bert)
211
- phones.append(temp_phones)
212
- tones.append(temp_tones)
213
- lang_ids.append(temp_lang_ids)
214
- bert = torch.concatenate(bert, dim=1)
215
- ja_bert = torch.concatenate(ja_bert, dim=1)
216
- en_bert = torch.concatenate(en_bert, dim=1)
217
- phones = torch.concatenate(phones, dim=0)
218
- tones = torch.concatenate(tones, dim=0)
219
- lang_ids = torch.concatenate(lang_ids, dim=0)
220
- with torch.no_grad():
221
- x_tst = phones.to(device).unsqueeze(0)
222
- tones = tones.to(device).unsqueeze(0)
223
- lang_ids = lang_ids.to(device).unsqueeze(0)
224
- bert = bert.to(device).unsqueeze(0)
225
- ja_bert = ja_bert.to(device).unsqueeze(0)
226
- en_bert = en_bert.to(device).unsqueeze(0)
227
- # emo = emo.to(device).unsqueeze(0)
228
- x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
229
- del phones
230
- speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
231
- audio = (
232
- net_g.infer(
233
- x_tst,
234
- x_tst_lengths,
235
- speakers,
236
- tones,
237
- lang_ids,
238
- bert,
239
- ja_bert,
240
- en_bert,
241
- style_vec=style_vec,
242
- sdp_ratio=sdp_ratio,
243
- noise_scale=noise_scale,
244
- noise_scale_w=noise_scale_w,
245
- length_scale=length_scale,
246
- )[0][0, 0]
247
- .data.cpu()
248
- .float()
249
- .numpy()
250
- )
251
- del (
252
- x_tst,
253
- tones,
254
- lang_ids,
255
- bert,
256
- x_tst_lengths,
257
- speakers,
258
- ja_bert,
259
- en_bert,
260
- ) # , emo
261
- if torch.cuda.is_available():
262
- torch.cuda.empty_cache()
263
- return audio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ import commons
4
+ import utils
5
+ from models import SynthesizerTrn
6
+ from models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra
7
+ from text import cleaned_text_to_sequence, get_bert
8
+ from text.cleaner import clean_text
9
+ from text.symbols import symbols
10
+ from common.log import logger
11
+
12
+
13
+ class InvalidToneError(ValueError):
14
+ pass
15
+
16
+
17
+ def get_net_g(model_path: str, version: str, device: str, hps):
18
+ if version.endswith("JP-Extra"):
19
+ #logger.info("Using JP-Extra model")
20
+ net_g = SynthesizerTrnJPExtra(
21
+ len(symbols),
22
+ hps.data.filter_length // 2 + 1,
23
+ hps.train.segment_size // hps.data.hop_length,
24
+ n_speakers=hps.data.n_speakers,
25
+ **hps.model,
26
+ ).to(device)
27
+ else:
28
+ #logger.info("Using normal model")
29
+ net_g = SynthesizerTrn(
30
+ len(symbols),
31
+ hps.data.filter_length // 2 + 1,
32
+ hps.train.segment_size // hps.data.hop_length,
33
+ n_speakers=hps.data.n_speakers,
34
+ **hps.model,
35
+ ).to(device)
36
+ net_g.state_dict()
37
+ _ = net_g.eval()
38
+ if model_path.endswith(".pth") or model_path.endswith(".pt"):
39
+ _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
40
+ elif model_path.endswith(".safetensors"):
41
+ _ = utils.load_safetensors(model_path, net_g, True)
42
+ else:
43
+ raise ValueError(f"Unknown model format: {model_path}")
44
+ return net_g
45
+
46
+
47
+ def get_text(
48
+ text,
49
+ language_str,
50
+ hps,
51
+ device,
52
+ assist_text=None,
53
+ assist_text_weight=0.7,
54
+ given_tone=None,
55
+ ):
56
+ use_jp_extra = hps.version.endswith("JP-Extra")
57
+ norm_text, phone, tone, word2ph = clean_text(text, language_str, use_jp_extra)
58
+ if given_tone is not None:
59
+ if len(given_tone) != len(phone):
60
+ raise InvalidToneError(
61
+ f"Length of given_tone ({len(given_tone)}) != length of phone ({len(phone)})"
62
+ )
63
+ tone = given_tone
64
+ phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
65
+
66
+ if hps.data.add_blank:
67
+ phone = commons.intersperse(phone, 0)
68
+ tone = commons.intersperse(tone, 0)
69
+ language = commons.intersperse(language, 0)
70
+ for i in range(len(word2ph)):
71
+ word2ph[i] = word2ph[i] * 2
72
+ word2ph[0] += 1
73
+ bert_ori = get_bert(
74
+ norm_text, word2ph, language_str, device, assist_text, assist_text_weight
75
+ )
76
+ del word2ph
77
+ assert bert_ori.shape[-1] == len(phone), phone
78
+
79
+ if language_str == "ZH":
80
+ bert = bert_ori
81
+ ja_bert = torch.zeros(1024, len(phone))
82
+ en_bert = torch.zeros(1024, len(phone))
83
+ elif language_str == "JP":
84
+ bert = torch.zeros(1024, len(phone))
85
+ ja_bert = bert_ori
86
+ en_bert = torch.zeros(1024, len(phone))
87
+ elif language_str == "EN":
88
+ bert = torch.zeros(1024, len(phone))
89
+ ja_bert = torch.zeros(1024, len(phone))
90
+ en_bert = bert_ori
91
+ else:
92
+ raise ValueError("language_str should be ZH, JP or EN")
93
+
94
+ assert bert.shape[-1] == len(
95
+ phone
96
+ ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
97
+
98
+ phone = torch.LongTensor(phone)
99
+ tone = torch.LongTensor(tone)
100
+ language = torch.LongTensor(language)
101
+ return bert, ja_bert, en_bert, phone, tone, language
102
+
103
+
104
+ def infer(
105
+ text,
106
+ style_vec,
107
+ sdp_ratio,
108
+ noise_scale,
109
+ noise_scale_w,
110
+ length_scale,
111
+ sid: int, # In the original Bert-VITS2, its speaker_name: str, but here it's id
112
+ language,
113
+ hps,
114
+ net_g,
115
+ device,
116
+ skip_start=False,
117
+ skip_end=False,
118
+ assist_text=None,
119
+ assist_text_weight=0.7,
120
+ given_tone=None,
121
+ ):
122
+ is_jp_extra = hps.version.endswith("JP-Extra")
123
+ bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
124
+ text,
125
+ language,
126
+ hps,
127
+ device,
128
+ assist_text=assist_text,
129
+ assist_text_weight=assist_text_weight,
130
+ given_tone=given_tone,
131
+ )
132
+ if skip_start:
133
+ phones = phones[3:]
134
+ tones = tones[3:]
135
+ lang_ids = lang_ids[3:]
136
+ bert = bert[:, 3:]
137
+ ja_bert = ja_bert[:, 3:]
138
+ en_bert = en_bert[:, 3:]
139
+ if skip_end:
140
+ phones = phones[:-2]
141
+ tones = tones[:-2]
142
+ lang_ids = lang_ids[:-2]
143
+ bert = bert[:, :-2]
144
+ ja_bert = ja_bert[:, :-2]
145
+ en_bert = en_bert[:, :-2]
146
+ with torch.no_grad():
147
+ x_tst = phones.to(device).unsqueeze(0)
148
+ tones = tones.to(device).unsqueeze(0)
149
+ lang_ids = lang_ids.to(device).unsqueeze(0)
150
+ bert = bert.to(device).unsqueeze(0)
151
+ ja_bert = ja_bert.to(device).unsqueeze(0)
152
+ en_bert = en_bert.to(device).unsqueeze(0)
153
+ x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
154
+ style_vec = torch.from_numpy(style_vec).to(device).unsqueeze(0)
155
+ del phones
156
+ sid_tensor = torch.LongTensor([sid]).to(device)
157
+ if is_jp_extra:
158
+ output = net_g.infer(
159
+ x_tst,
160
+ x_tst_lengths,
161
+ sid_tensor,
162
+ tones,
163
+ lang_ids,
164
+ ja_bert,
165
+ style_vec=style_vec,
166
+ sdp_ratio=sdp_ratio,
167
+ noise_scale=noise_scale,
168
+ noise_scale_w=noise_scale_w,
169
+ length_scale=length_scale,
170
+ )
171
+ else:
172
+ output = net_g.infer(
173
+ x_tst,
174
+ x_tst_lengths,
175
+ sid_tensor,
176
+ tones,
177
+ lang_ids,
178
+ bert,
179
+ ja_bert,
180
+ en_bert,
181
+ style_vec=style_vec,
182
+ sdp_ratio=sdp_ratio,
183
+ noise_scale=noise_scale,
184
+ noise_scale_w=noise_scale_w,
185
+ length_scale=length_scale,
186
+ )
187
+ audio = output[0][0, 0].data.cpu().float().numpy()
188
+ del (
189
+ x_tst,
190
+ tones,
191
+ lang_ids,
192
+ bert,
193
+ x_tst_lengths,
194
+ sid_tensor,
195
+ ja_bert,
196
+ en_bert,
197
+ style_vec,
198
+ ) # , emo
199
+ if torch.cuda.is_available():
200
+ torch.cuda.empty_cache()
201
+ return audio
202
+
203
+
204
+ def infer_multilang(
205
+ text,
206
+ style_vec,
207
+ sdp_ratio,
208
+ noise_scale,
209
+ noise_scale_w,
210
+ length_scale,
211
+ sid,
212
+ language,
213
+ hps,
214
+ net_g,
215
+ device,
216
+ skip_start=False,
217
+ skip_end=False,
218
+ ):
219
+ bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
220
+ # emo = get_emo_(reference_audio, emotion, sid)
221
+ # if isinstance(reference_audio, np.ndarray):
222
+ # emo = get_clap_audio_feature(reference_audio, device)
223
+ # else:
224
+ # emo = get_clap_text_feature(emotion, device)
225
+ # emo = torch.squeeze(emo, dim=1)
226
+ for idx, (txt, lang) in enumerate(zip(text, language)):
227
+ _skip_start = (idx != 0) or (skip_start and idx == 0)
228
+ _skip_end = (idx != len(language) - 1) or skip_end
229
+ (
230
+ temp_bert,
231
+ temp_ja_bert,
232
+ temp_en_bert,
233
+ temp_phones,
234
+ temp_tones,
235
+ temp_lang_ids,
236
+ ) = get_text(txt, lang, hps, device)
237
+ if _skip_start:
238
+ temp_bert = temp_bert[:, 3:]
239
+ temp_ja_bert = temp_ja_bert[:, 3:]
240
+ temp_en_bert = temp_en_bert[:, 3:]
241
+ temp_phones = temp_phones[3:]
242
+ temp_tones = temp_tones[3:]
243
+ temp_lang_ids = temp_lang_ids[3:]
244
+ if _skip_end:
245
+ temp_bert = temp_bert[:, :-2]
246
+ temp_ja_bert = temp_ja_bert[:, :-2]
247
+ temp_en_bert = temp_en_bert[:, :-2]
248
+ temp_phones = temp_phones[:-2]
249
+ temp_tones = temp_tones[:-2]
250
+ temp_lang_ids = temp_lang_ids[:-2]
251
+ bert.append(temp_bert)
252
+ ja_bert.append(temp_ja_bert)
253
+ en_bert.append(temp_en_bert)
254
+ phones.append(temp_phones)
255
+ tones.append(temp_tones)
256
+ lang_ids.append(temp_lang_ids)
257
+ bert = torch.concatenate(bert, dim=1)
258
+ ja_bert = torch.concatenate(ja_bert, dim=1)
259
+ en_bert = torch.concatenate(en_bert, dim=1)
260
+ phones = torch.concatenate(phones, dim=0)
261
+ tones = torch.concatenate(tones, dim=0)
262
+ lang_ids = torch.concatenate(lang_ids, dim=0)
263
+ with torch.no_grad():
264
+ x_tst = phones.to(device).unsqueeze(0)
265
+ tones = tones.to(device).unsqueeze(0)
266
+ lang_ids = lang_ids.to(device).unsqueeze(0)
267
+ bert = bert.to(device).unsqueeze(0)
268
+ ja_bert = ja_bert.to(device).unsqueeze(0)
269
+ en_bert = en_bert.to(device).unsqueeze(0)
270
+ # emo = emo.to(device).unsqueeze(0)
271
+ x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
272
+ del phones
273
+ speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
274
+ audio = (
275
+ net_g.infer(
276
+ x_tst,
277
+ x_tst_lengths,
278
+ speakers,
279
+ tones,
280
+ lang_ids,
281
+ bert,
282
+ ja_bert,
283
+ en_bert,
284
+ style_vec=style_vec,
285
+ sdp_ratio=sdp_ratio,
286
+ noise_scale=noise_scale,
287
+ noise_scale_w=noise_scale_w,
288
+ length_scale=length_scale,
289
+ )[0][0, 0]
290
+ .data.cpu()
291
+ .float()
292
+ .numpy()
293
+ )
294
+ del (
295
+ x_tst,
296
+ tones,
297
+ lang_ids,
298
+ bert,
299
+ x_tst_lengths,
300
+ speakers,
301
+ ja_bert,
302
+ en_bert,
303
+ ) # , emo
304
+ if torch.cuda.is_available():
305
+ torch.cuda.empty_cache()
306
+ return audio
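A hypothetical end-to-end sketch of the two entry points above. The model and config paths, the style-vector file, and `utils.get_hparams_from_file` as the config loader are assumptions for illustration only:

import numpy as np
import torch

import utils
from infer import get_net_g, infer

device = "cuda" if torch.cuda.is_available() else "cpu"
hps = utils.get_hparams_from_file("model_assets/some_model/config.json")  # assumed helper
net_g = get_net_g(
    model_path="model_assets/some_model/some_model.safetensors",
    version=hps.version,
    device=device,
    hps=hps,
)
style_vec = np.load("model_assets/some_model/style_vectors.npy")[0]  # e.g. the mean style vector
audio = infer(
    text="こんにちは",
    style_vec=style_vec,
    sdp_ratio=0.2,
    noise_scale=0.6,
    noise_scale_w=0.8,
    length_scale=1.0,
    sid=0,
    language="JP",
    hps=hps,
    net_g=net_g,
    device=device,
)  # float32 numpy waveform at hps.data.sampling_rate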
models_jp_extra.py ADDED
@@ -0,0 +1,1071 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ import modules
8
+ import attentions
9
+ import monotonic_align
10
+
11
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
12
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
+
14
+ from commons import init_weights, get_padding
15
+ from text import symbols, num_tones, num_languages
16
+
17
+
18
+ class DurationDiscriminator(nn.Module): # vits2
19
+ def __init__(
20
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
21
+ ):
22
+ super().__init__()
23
+
24
+ self.in_channels = in_channels
25
+ self.filter_channels = filter_channels
26
+ self.kernel_size = kernel_size
27
+ self.p_dropout = p_dropout
28
+ self.gin_channels = gin_channels
29
+
30
+ self.drop = nn.Dropout(p_dropout)
31
+ self.conv_1 = nn.Conv1d(
32
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
33
+ )
34
+ self.norm_1 = modules.LayerNorm(filter_channels)
35
+ self.conv_2 = nn.Conv1d(
36
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
37
+ )
38
+ self.norm_2 = modules.LayerNorm(filter_channels)
39
+ self.dur_proj = nn.Conv1d(1, filter_channels, 1)
40
+
41
+ self.LSTM = nn.LSTM(
42
+ 2 * filter_channels, filter_channels, batch_first=True, bidirectional=True
43
+ )
44
+
45
+ if gin_channels != 0:
46
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
47
+
48
+ self.output_layer = nn.Sequential(
49
+ nn.Linear(2 * filter_channels, 1), nn.Sigmoid()
50
+ )
51
+
52
+ def forward_probability(self, x, dur):
53
+ dur = self.dur_proj(dur)
54
+ x = torch.cat([x, dur], dim=1)
55
+ x = x.transpose(1, 2)
56
+ x, _ = self.LSTM(x)
57
+ output_prob = self.output_layer(x)
58
+ return output_prob
59
+
60
+ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
61
+ x = torch.detach(x)
62
+ if g is not None:
63
+ g = torch.detach(g)
64
+ x = x + self.cond(g)
65
+ x = self.conv_1(x * x_mask)
66
+ x = torch.relu(x)
67
+ x = self.norm_1(x)
68
+ x = self.drop(x)
69
+ x = self.conv_2(x * x_mask)
70
+ x = torch.relu(x)
71
+ x = self.norm_2(x)
72
+ x = self.drop(x)
73
+
74
+ output_probs = []
75
+ for dur in [dur_r, dur_hat]:
76
+ output_prob = self.forward_probability(x, dur)
77
+ output_probs.append(output_prob)
78
+
79
+ return output_probs
80
+
81
+
82
+ class TransformerCouplingBlock(nn.Module):
83
+ def __init__(
84
+ self,
85
+ channels,
86
+ hidden_channels,
87
+ filter_channels,
88
+ n_heads,
89
+ n_layers,
90
+ kernel_size,
91
+ p_dropout,
92
+ n_flows=4,
93
+ gin_channels=0,
94
+ share_parameter=False,
95
+ ):
96
+ super().__init__()
97
+ self.channels = channels
98
+ self.hidden_channels = hidden_channels
99
+ self.kernel_size = kernel_size
100
+ self.n_layers = n_layers
101
+ self.n_flows = n_flows
102
+ self.gin_channels = gin_channels
103
+
104
+ self.flows = nn.ModuleList()
105
+
106
+ self.wn = (
107
+ attentions.FFT(
108
+ hidden_channels,
109
+ filter_channels,
110
+ n_heads,
111
+ n_layers,
112
+ kernel_size,
113
+ p_dropout,
114
+ isflow=True,
115
+ gin_channels=self.gin_channels,
116
+ )
117
+ if share_parameter
118
+ else None
119
+ )
120
+
121
+ for i in range(n_flows):
122
+ self.flows.append(
123
+ modules.TransformerCouplingLayer(
124
+ channels,
125
+ hidden_channels,
126
+ kernel_size,
127
+ n_layers,
128
+ n_heads,
129
+ p_dropout,
130
+ filter_channels,
131
+ mean_only=True,
132
+ wn_sharing_parameter=self.wn,
133
+ gin_channels=self.gin_channels,
134
+ )
135
+ )
136
+ self.flows.append(modules.Flip())
137
+
138
+ def forward(self, x, x_mask, g=None, reverse=False):
139
+ if not reverse:
140
+ for flow in self.flows:
141
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
142
+ else:
143
+ for flow in reversed(self.flows):
144
+ x = flow(x, x_mask, g=g, reverse=reverse)
145
+ return x
146
+
147
+
148
+ class StochasticDurationPredictor(nn.Module):
149
+ def __init__(
150
+ self,
151
+ in_channels,
152
+ filter_channels,
153
+ kernel_size,
154
+ p_dropout,
155
+ n_flows=4,
156
+ gin_channels=0,
157
+ ):
158
+ super().__init__()
159
+ filter_channels = in_channels  # TODO: this override should be removed in a future version
160
+ self.in_channels = in_channels
161
+ self.filter_channels = filter_channels
162
+ self.kernel_size = kernel_size
163
+ self.p_dropout = p_dropout
164
+ self.n_flows = n_flows
165
+ self.gin_channels = gin_channels
166
+
167
+ self.log_flow = modules.Log()
168
+ self.flows = nn.ModuleList()
169
+ self.flows.append(modules.ElementwiseAffine(2))
170
+ for i in range(n_flows):
171
+ self.flows.append(
172
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
173
+ )
174
+ self.flows.append(modules.Flip())
175
+
176
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
177
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
178
+ self.post_convs = modules.DDSConv(
179
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
180
+ )
181
+ self.post_flows = nn.ModuleList()
182
+ self.post_flows.append(modules.ElementwiseAffine(2))
183
+ for i in range(4):
184
+ self.post_flows.append(
185
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
186
+ )
187
+ self.post_flows.append(modules.Flip())
188
+
189
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
190
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
191
+ self.convs = modules.DDSConv(
192
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
193
+ )
194
+ if gin_channels != 0:
195
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
196
+
197
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
198
+ x = torch.detach(x)
199
+ x = self.pre(x)
200
+ if g is not None:
201
+ g = torch.detach(g)
202
+ x = x + self.cond(g)
203
+ x = self.convs(x, x_mask)
204
+ x = self.proj(x) * x_mask
205
+
206
+ if not reverse:
207
+ flows = self.flows
208
+ assert w is not None
209
+
210
+ logdet_tot_q = 0
211
+ h_w = self.post_pre(w)
212
+ h_w = self.post_convs(h_w, x_mask)
213
+ h_w = self.post_proj(h_w) * x_mask
214
+ e_q = (
215
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
216
+ * x_mask
217
+ )
218
+ z_q = e_q
219
+ for flow in self.post_flows:
220
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
221
+ logdet_tot_q += logdet_q
222
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
223
+ u = torch.sigmoid(z_u) * x_mask
224
+ z0 = (w - u) * x_mask
225
+ logdet_tot_q += torch.sum(
226
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
227
+ )
228
+ logq = (
229
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
230
+ - logdet_tot_q
231
+ )
232
+
233
+ logdet_tot = 0
234
+ z0, logdet = self.log_flow(z0, x_mask)
235
+ logdet_tot += logdet
236
+ z = torch.cat([z0, z1], 1)
237
+ for flow in flows:
238
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
239
+ logdet_tot = logdet_tot + logdet
240
+ nll = (
241
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
242
+ - logdet_tot
243
+ )
244
+ return nll + logq # [b]
245
+ else:
246
+ flows = list(reversed(self.flows))
247
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
248
+ z = (
249
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
250
+ * noise_scale
251
+ )
252
+ for flow in flows:
253
+ z = flow(z, x_mask, g=x, reverse=reverse)
254
+ z0, z1 = torch.split(z, [1, 1], 1)
255
+ logw = z0
256
+ return logw
257
+
258
+
259
+ class DurationPredictor(nn.Module):
260
+ def __init__(
261
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
262
+ ):
263
+ super().__init__()
264
+
265
+ self.in_channels = in_channels
266
+ self.filter_channels = filter_channels
267
+ self.kernel_size = kernel_size
268
+ self.p_dropout = p_dropout
269
+ self.gin_channels = gin_channels
270
+
271
+ self.drop = nn.Dropout(p_dropout)
272
+ self.conv_1 = nn.Conv1d(
273
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
274
+ )
275
+ self.norm_1 = modules.LayerNorm(filter_channels)
276
+ self.conv_2 = nn.Conv1d(
277
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
278
+ )
279
+ self.norm_2 = modules.LayerNorm(filter_channels)
280
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
281
+
282
+ if gin_channels != 0:
283
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
284
+
285
+ def forward(self, x, x_mask, g=None):
286
+ x = torch.detach(x)
287
+ if g is not None:
288
+ g = torch.detach(g)
289
+ x = x + self.cond(g)
290
+ x = self.conv_1(x * x_mask)
291
+ x = torch.relu(x)
292
+ x = self.norm_1(x)
293
+ x = self.drop(x)
294
+ x = self.conv_2(x * x_mask)
295
+ x = torch.relu(x)
296
+ x = self.norm_2(x)
297
+ x = self.drop(x)
298
+ x = self.proj(x * x_mask)
299
+ return x * x_mask
300
+
301
+
302
+ class Bottleneck(nn.Sequential):
303
+ def __init__(self, in_dim, hidden_dim):
304
+ c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
305
+ c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
306
+ super().__init__(*[c_fc1, c_fc2])
307
+
308
+
309
+ class Block(nn.Module):
310
+ def __init__(self, in_dim, hidden_dim) -> None:
311
+ super().__init__()
312
+ self.norm = nn.LayerNorm(in_dim)
313
+ self.mlp = MLP(in_dim, hidden_dim)
314
+
315
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
316
+ x = x + self.mlp(self.norm(x))
317
+ return x
318
+
319
+
320
+ class MLP(nn.Module):
321
+ def __init__(self, in_dim, hidden_dim):
322
+ super().__init__()
323
+ self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
324
+ self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
325
+ self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False)
326
+
327
+ def forward(self, x: torch.Tensor):
328
+ x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
329
+ x = self.c_proj(x)
330
+ return x
331
+
332
+
333
+ class TextEncoder(nn.Module):
334
+ def __init__(
335
+ self,
336
+ n_vocab,
337
+ out_channels,
338
+ hidden_channels,
339
+ filter_channels,
340
+ n_heads,
341
+ n_layers,
342
+ kernel_size,
343
+ p_dropout,
344
+ gin_channels=0,
345
+ ):
346
+ super().__init__()
347
+ self.n_vocab = n_vocab
348
+ self.out_channels = out_channels
349
+ self.hidden_channels = hidden_channels
350
+ self.filter_channels = filter_channels
351
+ self.n_heads = n_heads
352
+ self.n_layers = n_layers
353
+ self.kernel_size = kernel_size
354
+ self.p_dropout = p_dropout
355
+ self.gin_channels = gin_channels
356
+ self.emb = nn.Embedding(len(symbols), hidden_channels)
357
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
358
+ self.tone_emb = nn.Embedding(num_tones, hidden_channels)
359
+ nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
360
+ self.language_emb = nn.Embedding(num_languages, hidden_channels)
361
+ nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
362
+ self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
363
+
364
+ # Remove emo_vq since it's not working well.
365
+ self.style_proj = nn.Linear(256, hidden_channels)
366
+
367
+ self.encoder = attentions.Encoder(
368
+ hidden_channels,
369
+ filter_channels,
370
+ n_heads,
371
+ n_layers,
372
+ kernel_size,
373
+ p_dropout,
374
+ gin_channels=self.gin_channels,
375
+ )
376
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
377
+
378
+ def forward(self, x, x_lengths, tone, language, bert, style_vec, g=None):
379
+ bert_emb = self.bert_proj(bert).transpose(1, 2)
380
+ style_emb = self.style_proj(style_vec.unsqueeze(1))
381
+ x = (
382
+ self.emb(x)
383
+ + self.tone_emb(tone)
384
+ + self.language_emb(language)
385
+ + bert_emb
386
+ + style_emb
387
+ ) * math.sqrt(
388
+ self.hidden_channels
389
+ ) # [b, t, h]
390
+ x = torch.transpose(x, 1, -1) # [b, h, t]
391
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
392
+ x.dtype
393
+ )
394
+
395
+ x = self.encoder(x * x_mask, x_mask, g=g)
396
+ stats = self.proj(x) * x_mask
397
+
398
+ m, logs = torch.split(stats, self.out_channels, dim=1)
399
+ return x, m, logs, x_mask
400
+
401
+
402
+ class ResidualCouplingBlock(nn.Module):
403
+ def __init__(
404
+ self,
405
+ channels,
406
+ hidden_channels,
407
+ kernel_size,
408
+ dilation_rate,
409
+ n_layers,
410
+ n_flows=4,
411
+ gin_channels=0,
412
+ ):
413
+ super().__init__()
414
+ self.channels = channels
415
+ self.hidden_channels = hidden_channels
416
+ self.kernel_size = kernel_size
417
+ self.dilation_rate = dilation_rate
418
+ self.n_layers = n_layers
419
+ self.n_flows = n_flows
420
+ self.gin_channels = gin_channels
421
+
422
+ self.flows = nn.ModuleList()
423
+ for i in range(n_flows):
424
+ self.flows.append(
425
+ modules.ResidualCouplingLayer(
426
+ channels,
427
+ hidden_channels,
428
+ kernel_size,
429
+ dilation_rate,
430
+ n_layers,
431
+ gin_channels=gin_channels,
432
+ mean_only=True,
433
+ )
434
+ )
435
+ self.flows.append(modules.Flip())
436
+
437
+ def forward(self, x, x_mask, g=None, reverse=False):
438
+ if not reverse:
439
+ for flow in self.flows:
440
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
441
+ else:
442
+ for flow in reversed(self.flows):
443
+ x = flow(x, x_mask, g=g, reverse=reverse)
444
+ return x
445
+
446
+
447
+ class PosteriorEncoder(nn.Module):
448
+ def __init__(
449
+ self,
450
+ in_channels,
451
+ out_channels,
452
+ hidden_channels,
453
+ kernel_size,
454
+ dilation_rate,
455
+ n_layers,
456
+ gin_channels=0,
457
+ ):
458
+ super().__init__()
459
+ self.in_channels = in_channels
460
+ self.out_channels = out_channels
461
+ self.hidden_channels = hidden_channels
462
+ self.kernel_size = kernel_size
463
+ self.dilation_rate = dilation_rate
464
+ self.n_layers = n_layers
465
+ self.gin_channels = gin_channels
466
+
467
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
468
+ self.enc = modules.WN(
469
+ hidden_channels,
470
+ kernel_size,
471
+ dilation_rate,
472
+ n_layers,
473
+ gin_channels=gin_channels,
474
+ )
475
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
476
+
477
+ def forward(self, x, x_lengths, g=None):
478
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
479
+ x.dtype
480
+ )
481
+ x = self.pre(x) * x_mask
482
+ x = self.enc(x, x_mask, g=g)
483
+ stats = self.proj(x) * x_mask
484
+ m, logs = torch.split(stats, self.out_channels, dim=1)
485
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
486
+ return z, m, logs, x_mask
487
+
488
+
489
+ class Generator(torch.nn.Module):
490
+ def __init__(
491
+ self,
492
+ initial_channel,
493
+ resblock,
494
+ resblock_kernel_sizes,
495
+ resblock_dilation_sizes,
496
+ upsample_rates,
497
+ upsample_initial_channel,
498
+ upsample_kernel_sizes,
499
+ gin_channels=0,
500
+ ):
501
+ super(Generator, self).__init__()
502
+ self.num_kernels = len(resblock_kernel_sizes)
503
+ self.num_upsamples = len(upsample_rates)
504
+ self.conv_pre = Conv1d(
505
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
506
+ )
507
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
508
+
509
+ self.ups = nn.ModuleList()
510
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
511
+ self.ups.append(
512
+ weight_norm(
513
+ ConvTranspose1d(
514
+ upsample_initial_channel // (2**i),
515
+ upsample_initial_channel // (2 ** (i + 1)),
516
+ k,
517
+ u,
518
+ padding=(k - u) // 2,
519
+ )
520
+ )
521
+ )
522
+
523
+ self.resblocks = nn.ModuleList()
524
+ for i in range(len(self.ups)):
525
+ ch = upsample_initial_channel // (2 ** (i + 1))
526
+ for j, (k, d) in enumerate(
527
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
528
+ ):
529
+ self.resblocks.append(resblock(ch, k, d))
530
+
531
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
532
+ self.ups.apply(init_weights)
533
+
534
+ if gin_channels != 0:
535
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
536
+
537
+ def forward(self, x, g=None):
538
+ x = self.conv_pre(x)
539
+ if g is not None:
540
+ x = x + self.cond(g)
541
+
542
+ for i in range(self.num_upsamples):
543
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
544
+ x = self.ups[i](x)
545
+ xs = None
546
+ for j in range(self.num_kernels):
547
+ if xs is None:
548
+ xs = self.resblocks[i * self.num_kernels + j](x)
549
+ else:
550
+ xs += self.resblocks[i * self.num_kernels + j](x)
551
+ x = xs / self.num_kernels
552
+ x = F.leaky_relu(x)
553
+ x = self.conv_post(x)
554
+ x = torch.tanh(x)
555
+
556
+ return x
557
+
558
+ def remove_weight_norm(self):
559
+ print("Removing weight norm...")
560
+ for layer in self.ups:
561
+ remove_weight_norm(layer)
562
+ for layer in self.resblocks:
563
+ layer.remove_weight_norm()
564
+
565
+
566
+ class DiscriminatorP(torch.nn.Module):
567
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
568
+ super(DiscriminatorP, self).__init__()
569
+ self.period = period
570
+ self.use_spectral_norm = use_spectral_norm
571
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
572
+ self.convs = nn.ModuleList(
573
+ [
574
+ norm_f(
575
+ Conv2d(
576
+ 1,
577
+ 32,
578
+ (kernel_size, 1),
579
+ (stride, 1),
580
+ padding=(get_padding(kernel_size, 1), 0),
581
+ )
582
+ ),
583
+ norm_f(
584
+ Conv2d(
585
+ 32,
586
+ 128,
587
+ (kernel_size, 1),
588
+ (stride, 1),
589
+ padding=(get_padding(kernel_size, 1), 0),
590
+ )
591
+ ),
592
+ norm_f(
593
+ Conv2d(
594
+ 128,
595
+ 512,
596
+ (kernel_size, 1),
597
+ (stride, 1),
598
+ padding=(get_padding(kernel_size, 1), 0),
599
+ )
600
+ ),
601
+ norm_f(
602
+ Conv2d(
603
+ 512,
604
+ 1024,
605
+ (kernel_size, 1),
606
+ (stride, 1),
607
+ padding=(get_padding(kernel_size, 1), 0),
608
+ )
609
+ ),
610
+ norm_f(
611
+ Conv2d(
612
+ 1024,
613
+ 1024,
614
+ (kernel_size, 1),
615
+ 1,
616
+ padding=(get_padding(kernel_size, 1), 0),
617
+ )
618
+ ),
619
+ ]
620
+ )
621
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
622
+
623
+ def forward(self, x):
624
+ fmap = []
625
+
626
+ # 1d to 2d
627
+ b, c, t = x.shape
628
+ if t % self.period != 0: # pad first
629
+ n_pad = self.period - (t % self.period)
630
+ x = F.pad(x, (0, n_pad), "reflect")
631
+ t = t + n_pad
632
+ x = x.view(b, c, t // self.period, self.period)
633
+
634
+ for layer in self.convs:
635
+ x = layer(x)
636
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
637
+ fmap.append(x)
638
+ x = self.conv_post(x)
639
+ fmap.append(x)
640
+ x = torch.flatten(x, 1, -1)
641
+
642
+ return x, fmap
643
+
644
+
645
+ class DiscriminatorS(torch.nn.Module):
646
+ def __init__(self, use_spectral_norm=False):
647
+ super(DiscriminatorS, self).__init__()
648
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
649
+ self.convs = nn.ModuleList(
650
+ [
651
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
652
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
653
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
654
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
655
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
656
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
657
+ ]
658
+ )
659
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
660
+
661
+ def forward(self, x):
662
+ fmap = []
663
+
664
+ for layer in self.convs:
665
+ x = layer(x)
666
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
667
+ fmap.append(x)
668
+ x = self.conv_post(x)
669
+ fmap.append(x)
670
+ x = torch.flatten(x, 1, -1)
671
+
672
+ return x, fmap
673
+
674
+
675
+ class MultiPeriodDiscriminator(torch.nn.Module):
676
+ def __init__(self, use_spectral_norm=False):
677
+ super(MultiPeriodDiscriminator, self).__init__()
678
+ periods = [2, 3, 5, 7, 11]
679
+
680
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
681
+ discs = discs + [
682
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
683
+ ]
684
+ self.discriminators = nn.ModuleList(discs)
685
+
686
+ def forward(self, y, y_hat):
687
+ y_d_rs = []
688
+ y_d_gs = []
689
+ fmap_rs = []
690
+ fmap_gs = []
691
+ for i, d in enumerate(self.discriminators):
692
+ y_d_r, fmap_r = d(y)
693
+ y_d_g, fmap_g = d(y_hat)
694
+ y_d_rs.append(y_d_r)
695
+ y_d_gs.append(y_d_g)
696
+ fmap_rs.append(fmap_r)
697
+ fmap_gs.append(fmap_g)
698
+
699
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
700
+
701
+
702
+ class WavLMDiscriminator(nn.Module):
703
+ """docstring for Discriminator."""
704
+
705
+ def __init__(
706
+ self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
707
+ ):
708
+ super(WavLMDiscriminator, self).__init__()
709
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
710
+ self.pre = norm_f(
711
+ Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
712
+ )
713
+
714
+ self.convs = nn.ModuleList(
715
+ [
716
+ norm_f(
717
+ nn.Conv1d(
718
+ initial_channel, initial_channel * 2, kernel_size=5, padding=2
719
+ )
720
+ ),
721
+ norm_f(
722
+ nn.Conv1d(
723
+ initial_channel * 2,
724
+ initial_channel * 4,
725
+ kernel_size=5,
726
+ padding=2,
727
+ )
728
+ ),
729
+ norm_f(
730
+ nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
731
+ ),
732
+ ]
733
+ )
734
+
735
+ self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
736
+
737
+ def forward(self, x):
738
+ x = self.pre(x)
739
+
740
+ fmap = []
741
+ for l in self.convs:
742
+ x = l(x)
743
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
744
+ fmap.append(x)
745
+ x = self.conv_post(x)
746
+ x = torch.flatten(x, 1, -1)
747
+
748
+ return x
749
+
750
+
751
+ class ReferenceEncoder(nn.Module):
752
+ """
753
+ inputs --- [N, Ty/r, n_mels*r] mels
754
+ outputs --- [N, ref_enc_gru_size]
755
+ """
756
+
757
+ def __init__(self, spec_channels, gin_channels=0):
758
+ super().__init__()
759
+ self.spec_channels = spec_channels
760
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
761
+ K = len(ref_enc_filters)
762
+ filters = [1] + ref_enc_filters
763
+ convs = [
764
+ weight_norm(
765
+ nn.Conv2d(
766
+ in_channels=filters[i],
767
+ out_channels=filters[i + 1],
768
+ kernel_size=(3, 3),
769
+ stride=(2, 2),
770
+ padding=(1, 1),
771
+ )
772
+ )
773
+ for i in range(K)
774
+ ]
775
+ self.convs = nn.ModuleList(convs)
776
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501
777
+
778
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
779
+ self.gru = nn.GRU(
780
+ input_size=ref_enc_filters[-1] * out_channels,
781
+ hidden_size=256 // 2,
782
+ batch_first=True,
783
+ )
784
+ self.proj = nn.Linear(128, gin_channels)
785
+
786
+ def forward(self, inputs, mask=None):
787
+ N = inputs.size(0)
788
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
789
+ for conv in self.convs:
790
+ out = conv(out)
791
+ # out = wn(out)
792
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
793
+
794
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
795
+ T = out.size(1)
796
+ N = out.size(0)
797
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
798
+
799
+ self.gru.flatten_parameters()
800
+ memory, out = self.gru(out) # out --- [1, N, 128]
801
+
802
+ return self.proj(out.squeeze(0))
803
+
804
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
805
+ for i in range(n_convs):
806
+ L = (L - kernel_size + 2 * pad) // stride + 1
807
+ return L
808
+
809
+
810
+ class SynthesizerTrn(nn.Module):
811
+ """
812
+ Synthesizer for Training
813
+ """
814
+
815
+ def __init__(
816
+ self,
817
+ n_vocab,
818
+ spec_channels,
819
+ segment_size,
820
+ inter_channels,
821
+ hidden_channels,
822
+ filter_channels,
823
+ n_heads,
824
+ n_layers,
825
+ kernel_size,
826
+ p_dropout,
827
+ resblock,
828
+ resblock_kernel_sizes,
829
+ resblock_dilation_sizes,
830
+ upsample_rates,
831
+ upsample_initial_channel,
832
+ upsample_kernel_sizes,
833
+ n_speakers=256,
834
+ gin_channels=256,
835
+ use_sdp=True,
836
+ n_flow_layer=4,
837
+ n_layers_trans_flow=6,
838
+ flow_share_parameter=False,
839
+ use_transformer_flow=True,
840
+ **kwargs
841
+ ):
842
+ super().__init__()
843
+ self.n_vocab = n_vocab
844
+ self.spec_channels = spec_channels
845
+ self.inter_channels = inter_channels
846
+ self.hidden_channels = hidden_channels
847
+ self.filter_channels = filter_channels
848
+ self.n_heads = n_heads
849
+ self.n_layers = n_layers
850
+ self.kernel_size = kernel_size
851
+ self.p_dropout = p_dropout
852
+ self.resblock = resblock
853
+ self.resblock_kernel_sizes = resblock_kernel_sizes
854
+ self.resblock_dilation_sizes = resblock_dilation_sizes
855
+ self.upsample_rates = upsample_rates
856
+ self.upsample_initial_channel = upsample_initial_channel
857
+ self.upsample_kernel_sizes = upsample_kernel_sizes
858
+ self.segment_size = segment_size
859
+ self.n_speakers = n_speakers
860
+ self.gin_channels = gin_channels
861
+ self.n_layers_trans_flow = n_layers_trans_flow
862
+ self.use_spk_conditioned_encoder = kwargs.get(
863
+ "use_spk_conditioned_encoder", True
864
+ )
865
+ self.use_sdp = use_sdp
866
+ self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
867
+ self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
868
+ self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
869
+ self.current_mas_noise_scale = self.mas_noise_scale_initial
870
+ if self.use_spk_conditioned_encoder and gin_channels > 0:
871
+ self.enc_gin_channels = gin_channels
872
+ self.enc_p = TextEncoder(
873
+ n_vocab,
874
+ inter_channels,
875
+ hidden_channels,
876
+ filter_channels,
877
+ n_heads,
878
+ n_layers,
879
+ kernel_size,
880
+ p_dropout,
881
+ gin_channels=self.enc_gin_channels,
882
+ )
883
+ self.dec = Generator(
884
+ inter_channels,
885
+ resblock,
886
+ resblock_kernel_sizes,
887
+ resblock_dilation_sizes,
888
+ upsample_rates,
889
+ upsample_initial_channel,
890
+ upsample_kernel_sizes,
891
+ gin_channels=gin_channels,
892
+ )
893
+ self.enc_q = PosteriorEncoder(
894
+ spec_channels,
895
+ inter_channels,
896
+ hidden_channels,
897
+ 5,
898
+ 1,
899
+ 16,
900
+ gin_channels=gin_channels,
901
+ )
902
+ if use_transformer_flow:
903
+ self.flow = TransformerCouplingBlock(
904
+ inter_channels,
905
+ hidden_channels,
906
+ filter_channels,
907
+ n_heads,
908
+ n_layers_trans_flow,
909
+ 5,
910
+ p_dropout,
911
+ n_flow_layer,
912
+ gin_channels=gin_channels,
913
+ share_parameter=flow_share_parameter,
914
+ )
915
+ else:
916
+ self.flow = ResidualCouplingBlock(
917
+ inter_channels,
918
+ hidden_channels,
919
+ 5,
920
+ 1,
921
+ n_flow_layer,
922
+ gin_channels=gin_channels,
923
+ )
924
+ self.sdp = StochasticDurationPredictor(
925
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
926
+ )
927
+ self.dp = DurationPredictor(
928
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
929
+ )
930
+
931
+ if n_speakers >= 1:
932
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
933
+ else:
934
+ self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
935
+
936
+ def forward(
937
+ self,
938
+ x,
939
+ x_lengths,
940
+ y,
941
+ y_lengths,
942
+ sid,
943
+ tone,
944
+ language,
945
+ bert,
946
+ style_vec,
947
+ ):
948
+ if self.n_speakers > 0:
949
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
950
+ else:
951
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
952
+ x, m_p, logs_p, x_mask = self.enc_p(
953
+ x, x_lengths, tone, language, bert, style_vec, g=g
954
+ )
955
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
956
+ z_p = self.flow(z, y_mask, g=g)
957
+
958
+ with torch.no_grad():
959
+ # negative cross-entropy
960
+ s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
961
+ neg_cent1 = torch.sum(
962
+ -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
963
+ ) # [b, 1, t_s]
964
+ neg_cent2 = torch.matmul(
965
+ -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
966
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
967
+ neg_cent3 = torch.matmul(
968
+ z_p.transpose(1, 2), (m_p * s_p_sq_r)
969
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
970
+ neg_cent4 = torch.sum(
971
+ -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
972
+ ) # [b, 1, t_s]
973
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
974
+ if self.use_noise_scaled_mas:
975
+ epsilon = (
976
+ torch.std(neg_cent)
977
+ * torch.randn_like(neg_cent)
978
+ * self.current_mas_noise_scale
979
+ )
980
+ neg_cent = neg_cent + epsilon
981
+
982
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
983
+ attn = (
984
+ monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
985
+ .unsqueeze(1)
986
+ .detach()
987
+ )
988
+
989
+ w = attn.sum(2)
990
+
991
+ l_length_sdp = self.sdp(x, x_mask, w, g=g)
992
+ l_length_sdp = l_length_sdp / torch.sum(x_mask)
993
+
994
+ logw_ = torch.log(w + 1e-6) * x_mask
995
+ logw = self.dp(x, x_mask, g=g)
996
+ # logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
997
+ l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
998
+ x_mask
999
+ ) # for averaging
1000
+ # l_length_sdp += torch.sum((logw_sdp - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
1001
+
1002
+ l_length = l_length_dp + l_length_sdp
1003
+
1004
+ # expand prior
1005
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
1006
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
1007
+
1008
+ z_slice, ids_slice = commons.rand_slice_segments(
1009
+ z, y_lengths, self.segment_size
1010
+ )
1011
+ o = self.dec(z_slice, g=g)
1012
+ return (
1013
+ o,
1014
+ l_length,
1015
+ attn,
1016
+ ids_slice,
1017
+ x_mask,
1018
+ y_mask,
1019
+ (z, z_p, m_p, logs_p, m_q, logs_q),
1020
+ (x, logw, logw_), # , logw_sdp),
1021
+ g,
1022
+ )
1023
+
1024
+ def infer(
1025
+ self,
1026
+ x,
1027
+ x_lengths,
1028
+ sid,
1029
+ tone,
1030
+ language,
1031
+ bert,
1032
+ style_vec,
1033
+ noise_scale=0.667,
1034
+ length_scale=1,
1035
+ noise_scale_w=0.8,
1036
+ max_len=None,
1037
+ sdp_ratio=0,
1038
+ y=None,
1039
+ ):
1040
+ # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
1041
+ # g = self.gst(y)
1042
+ if self.n_speakers > 0:
1043
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
1044
+ else:
1045
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
1046
+ x, m_p, logs_p, x_mask = self.enc_p(
1047
+ x, x_lengths, tone, language, bert, style_vec, g=g
1048
+ )
1049
+ logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
1050
+ sdp_ratio
1051
+ ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
1052
+ w = torch.exp(logw) * x_mask * length_scale
1053
+ w_ceil = torch.ceil(w)
1054
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
1055
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
1056
+ x_mask.dtype
1057
+ )
1058
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
1059
+ attn = commons.generate_path(w_ceil, attn_mask)
1060
+
1061
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
1062
+ 1, 2
1063
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1064
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
1065
+ 1, 2
1066
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1067
+
1068
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1069
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
1070
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
1071
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
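In `SynthesizerTrn.infer` above, per-phoneme durations come from a blend of the stochastic (`sdp`) and deterministic (`dp`) predictors controlled by `sdp_ratio`. A toy illustration of that blend with dummy tensors (not model code):

import torch

logw_sdp = torch.tensor([[0.1, 0.5, 0.3]])  # stochastic predictor output (log-durations)
logw_dp = torch.tensor([[0.2, 0.4, 0.2]])   # deterministic predictor output
sdp_ratio, length_scale = 0.2, 1.0

logw = logw_sdp * sdp_ratio + logw_dp * (1 - sdp_ratio)
w = torch.exp(logw) * length_scale  # per-phoneme durations in frames
w_ceil = torch.ceil(w)              # integer frame counts used to build the alignment path
print(w_ceil)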
monotonic_align/__init__.py CHANGED
@@ -1,16 +1,16 @@
1
- from numpy import zeros, int32, float32
2
- from torch import from_numpy
3
-
4
- from .core import maximum_path_jit
5
-
6
-
7
- def maximum_path(neg_cent, mask):
8
- device = neg_cent.device
9
- dtype = neg_cent.dtype
10
- neg_cent = neg_cent.data.cpu().numpy().astype(float32)
11
- path = zeros(neg_cent.shape, dtype=int32)
12
-
13
- t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
14
- t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
15
- maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
16
- return from_numpy(path).to(device=device, dtype=dtype)
 
1
+ from numpy import zeros, int32, float32
2
+ from torch import from_numpy
3
+
4
+ from .core import maximum_path_jit
5
+
6
+
7
+ def maximum_path(neg_cent, mask):
8
+ device = neg_cent.device
9
+ dtype = neg_cent.dtype
10
+ neg_cent = neg_cent.data.cpu().numpy().astype(float32)
11
+ path = zeros(neg_cent.shape, dtype=int32)
12
+
13
+ t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
14
+ t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
15
+ maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
16
+ return from_numpy(path).to(device=device, dtype=dtype)
monotonic_align/core.py CHANGED
@@ -1,46 +1,46 @@
1
- import numba
2
-
3
-
4
- @numba.jit(
5
- numba.void(
6
- numba.int32[:, :, ::1],
7
- numba.float32[:, :, ::1],
8
- numba.int32[::1],
9
- numba.int32[::1],
10
- ),
11
- nopython=True,
12
- nogil=True,
13
- )
14
- def maximum_path_jit(paths, values, t_ys, t_xs):
15
- b = paths.shape[0]
16
- max_neg_val = -1e9
17
- for i in range(int(b)):
18
- path = paths[i]
19
- value = values[i]
20
- t_y = t_ys[i]
21
- t_x = t_xs[i]
22
-
23
- v_prev = v_cur = 0.0
24
- index = t_x - 1
25
-
26
- for y in range(t_y):
27
- for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
28
- if x == y:
29
- v_cur = max_neg_val
30
- else:
31
- v_cur = value[y - 1, x]
32
- if x == 0:
33
- if y == 0:
34
- v_prev = 0.0
35
- else:
36
- v_prev = max_neg_val
37
- else:
38
- v_prev = value[y - 1, x - 1]
39
- value[y, x] += max(v_prev, v_cur)
40
-
41
- for y in range(t_y - 1, -1, -1):
42
- path[y, index] = 1
43
- if index != 0 and (
44
- index == y or value[y - 1, index] < value[y - 1, index - 1]
45
- ):
46
- index = index - 1
 
1
+ import numba
2
+
3
+
4
+ @numba.jit(
5
+ numba.void(
6
+ numba.int32[:, :, ::1],
7
+ numba.float32[:, :, ::1],
8
+ numba.int32[::1],
9
+ numba.int32[::1],
10
+ ),
11
+ nopython=True,
12
+ nogil=True,
13
+ )
14
+ def maximum_path_jit(paths, values, t_ys, t_xs):
15
+ b = paths.shape[0]
16
+ max_neg_val = -1e9
17
+ for i in range(int(b)):
18
+ path = paths[i]
19
+ value = values[i]
20
+ t_y = t_ys[i]
21
+ t_x = t_xs[i]
22
+
23
+ v_prev = v_cur = 0.0
24
+ index = t_x - 1
25
+
26
+ for y in range(t_y):
27
+ for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
28
+ if x == y:
29
+ v_cur = max_neg_val
30
+ else:
31
+ v_cur = value[y - 1, x]
32
+ if x == 0:
33
+ if y == 0:
34
+ v_prev = 0.0
35
+ else:
36
+ v_prev = max_neg_val
37
+ else:
38
+ v_prev = value[y - 1, x - 1]
39
+ value[y, x] += max(v_prev, v_cur)
40
+
41
+ for y in range(t_y - 1, -1, -1):
42
+ path[y, index] = 1
43
+ if index != 0 and (
44
+ index == y or value[y - 1, index] < value[y - 1, index - 1]
45
+ ):
46
+ index = index - 1
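`maximum_path` is called from `SynthesizerTrn.forward` with a negative cross-entropy matrix of shape [batch, frames, phonemes] and the matching attention mask. A dummy-shaped example (shapes assumed from the model code above):

import torch
from monotonic_align import maximum_path

b, t_frames, t_phones = 1, 6, 4
neg_cent = torch.randn(b, t_frames, t_phones)  # alignment scores
mask = torch.ones(b, t_frames, t_phones)       # attn_mask.squeeze(1) in the model code
path = maximum_path(neg_cent, mask)            # 0/1 monotonic alignment of the same shape
print(path[0])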
requirements.txt CHANGED
@@ -1,6 +1,5 @@
1
  cmudict
2
  cn2an
3
- faster-whisper>=0.10.0
4
  g2p_en
5
  GPUtil
6
  gradio
@@ -15,14 +14,14 @@ num2words
15
  numba
16
  numpy
17
  psutil
18
- pyannote.audio>=3.1.0
19
  pyopenjtalk-prebuilt
20
  pypinyin
21
  PyYAML
22
  requests
23
- sentencepiece
24
  safetensors
25
  scipy
 
26
  tensorboard
27
- torch
28
  transformers
 
1
  cmudict
2
  cn2an
 
3
  g2p_en
4
  GPUtil
5
  gradio
 
14
  numba
15
  numpy
16
  psutil
17
+ pyannote.audio
18
  pyopenjtalk-prebuilt
19
  pypinyin
20
  PyYAML
21
  requests
 
22
  safetensors
23
  scipy
24
+ sentencepiece
25
  tensorboard
26
+ torch>=2.1,<2.2 # For users without GPU or colab
27
  transformers
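The torch pin above targets the 2.1 series only. A purely illustrative sanity check of the installed version:

import torch

major, minor = (int(x) for x in torch.__version__.split(".")[:2])
assert (major, minor) == (2, 1), f"requirements.txt expects torch>=2.1,<2.2, found {torch.__version__}"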
style_gen.py CHANGED
@@ -1,6 +1,5 @@
1
  import argparse
2
- import concurrent.futures
3
- import sys
4
  import warnings
5
 
6
  import numpy as np
@@ -8,6 +7,8 @@ import torch
8
  from tqdm import tqdm
9
 
10
  import utils
 
 
11
  from config import config
12
 
13
  warnings.filterwarnings("ignore", category=UserWarning)
@@ -19,14 +20,44 @@ device = torch.device(config.style_gen_config.device)
19
  inference.to(device)
20
 
21
 
22
- def extract_style_vector(wav_path):
 
23
  return inference(wav_path)
24
 
25
 
26
  def save_style_vector(wav_path):
27
- style_vec = extract_style_vector(wav_path)
28
- # `test.wav` -> `test.wav.npy`
29
- np.save(f"{wav_path}.npy", style_vec)
30
 
31
 
32
  if __name__ == "__main__":
@@ -45,22 +76,53 @@ if __name__ == "__main__":
45
 
46
  device = config.style_gen_config.device
47
 
48
- lines = []
49
  with open(hps.data.training_files, encoding="utf-8") as f:
50
- lines.extend(f.readlines())
51
52
  with open(hps.data.validation_files, encoding="utf-8") as f:
53
- lines.extend(f.readlines())
54
-
55
- wavnames = [line.split("|")[0] for line in lines]
56
 
57
- with concurrent.futures.ThreadPoolExecutor(max_workers=num_processes) as executor:
58
- list(
59
  tqdm(
60
- executor.map(save_style_vector, wavnames),
61
- total=len(wavnames),
62
- file=sys.stdout,
63
  )
64
  )
65
 
66
- print(f"Finished generating style vectors! total: {len(wavnames)} npy files.")
 
1
  import argparse
2
+ from concurrent.futures import ThreadPoolExecutor
 
3
  import warnings
4
 
5
  import numpy as np
 
7
  from tqdm import tqdm
8
 
9
  import utils
10
+ from common.log import logger
11
+ from common.stdout_wrapper import SAFE_STDOUT
12
  from config import config
13
 
14
  warnings.filterwarnings("ignore", category=UserWarning)
 
20
  inference.to(device)
21
 
22
 
23
+ class NaNValueError(ValueError):
24
+ """カスタム例外クラス。NaN値が見つかった場合に使用されます。"""
25
+
26
+ pass
27
+
28
+
29
+ # Defined as a (trivially short) function so it can be imported at inference time
30
+ def get_style_vector(wav_path):
31
  return inference(wav_path)
32
 
33
 
34
  def save_style_vector(wav_path):
35
+ try:
36
+ style_vec = get_style_vector(wav_path)
37
+ except Exception as e:
38
+ print("\n")
39
+ logger.error(f"Error occurred with file: {wav_path}, Details:\n{e}\n")
40
+ raise
41
+ # Check for NaN values, since they would corrupt the results downstream
42
+ if np.isnan(style_vec).any():
43
+ print("\n")
44
+ logger.warning(f"NaN value found in style vector: {wav_path}")
45
+ raise NaNValueError(f"NaN value found in style vector: {wav_path}")
46
+ np.save(f"{wav_path}.npy", style_vec) # `test.wav` -> `test.wav.npy`
47
+
48
+
49
+ def process_line(line):
50
+ wavname = line.split("|")[0]
51
+ try:
52
+ save_style_vector(wavname)
53
+ return line, None
54
+ except NaNValueError:
55
+ return line, "nan_error"
56
+
57
+
58
+ def save_average_style_vector(style_vectors, filename="style_vectors.npy"):
59
+ average_vector = np.mean(style_vectors, axis=0)
60
+ np.save(filename, average_vector)
61
 
62
 
63
  if __name__ == "__main__":
 
76
 
77
  device = config.style_gen_config.device
78
 
79
+ training_lines = []
80
  with open(hps.data.training_files, encoding="utf-8") as f:
81
+ training_lines.extend(f.readlines())
82
+ with ThreadPoolExecutor(max_workers=num_processes) as executor:
83
+ training_results = list(
84
+ tqdm(
85
+ executor.map(process_line, training_lines),
86
+ total=len(training_lines),
87
+ file=SAFE_STDOUT,
88
+ )
89
+ )
90
+ ok_training_lines = [line for line, error in training_results if error is None]
91
+ nan_training_lines = [
92
+ line for line, error in training_results if error == "nan_error"
93
+ ]
94
+ if nan_training_lines:
95
+ nan_files = [line.split("|")[0] for line in nan_training_lines]
96
+ logger.warning(
97
+ f"Found NaN value in {len(nan_training_lines)} files: {nan_files}, so they will be deleted from training data."
98
+ )
99
 
100
+ val_lines = []
101
  with open(hps.data.validation_files, encoding="utf-8") as f:
102
+ val_lines.extend(f.readlines())
 
 
103
 
104
+ with ThreadPoolExecutor(max_workers=num_processes) as executor:
105
+ val_results = list(
106
  tqdm(
107
+ executor.map(process_line, val_lines),
108
+ total=len(val_lines),
109
+ file=SAFE_STDOUT,
110
  )
111
  )
112
+ ok_val_lines = [line for line, error in val_results if error is None]
113
+ nan_val_lines = [line for line, error in val_results if error == "nan_error"]
114
+ if nan_val_lines:
115
+ nan_files = [line.split("|")[0] for line in nan_val_lines]
116
+ logger.warning(
117
+ f"Found NaN value in {len(nan_val_lines)} files: {nan_files}, so they will be deleted from validation data."
118
+ )
119
+
120
+ with open(hps.data.training_files, "w", encoding="utf-8") as f:
121
+ f.writelines(ok_training_lines)
122
+
123
+ with open(hps.data.validation_files, "w", encoding="utf-8") as f:
124
+ f.writelines(ok_val_lines)
125
+
126
+ ok_num = len(ok_training_lines) + len(ok_val_lines)
127
 
128
+ logger.info(f"Finished generating style vectors! total: {ok_num} npy files.")
text/__init__.py ADDED
@@ -0,0 +1,32 @@
1
+ from text.symbols import *
2
+
3
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
4
+
5
+
6
+ def cleaned_text_to_sequence(cleaned_text, tones, language):
7
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
8
+ Args:
9
+ cleaned_text: list of phoneme symbols to convert to a sequence
10
+ Returns:
11
+ List of integers corresponding to the symbols in the text
12
+ """
13
+ phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
14
+ tone_start = language_tone_start_map[language]
15
+ tones = [i + tone_start for i in tones]
16
+ lang_id = language_id_map[language]
17
+ lang_ids = [lang_id for i in phones]
18
+ return phones, tones, lang_ids
19
+
20
+
21
+ def get_bert(
22
+ norm_text, word2ph, language, device, assist_text=None, assist_text_weight=0.7
23
+ ):
24
+ from .chinese_bert import get_bert_feature as zh_bert
25
+ from .english_bert_mock import get_bert_feature as en_bert
26
+ from .japanese_bert import get_bert_feature as jp_bert
27
+
28
+ lang_bert_func_map = {"ZH": zh_bert, "EN": en_bert, "JP": jp_bert}
29
+ bert = lang_bert_func_map[language](
30
+ norm_text, word2ph, device, assist_text, assist_text_weight
31
+ )
32
+ return bert
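A minimal usage sketch of cleaned_text_to_sequence(); the phones and tones below are illustrative only, and real callers obtain them from the g2p pipeline in text.cleaner rather than writing them by hand.

from text import cleaned_text_to_sequence

phones = ["_", "k", "o", "N", "n", "i", "ch", "i", "w", "a", "_"]  # illustrative JP phones
tones = [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0]
phone_ids, tone_ids, lang_ids = cleaned_text_to_sequence(phones, tones, "JP")
# All three lists share the same length: tones are shifted by the JP tone offset,
# and lang_ids repeat the JP language id once per phone.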
text/chinese.py ADDED
@@ -0,0 +1,199 @@
1
+ import os
2
+ import re
3
+
4
+ import cn2an
5
+ from pypinyin import lazy_pinyin, Style
6
+
7
+ from text.symbols import punctuation
8
+ from text.tone_sandhi import ToneSandhi
9
+
10
+ current_file_path = os.path.dirname(__file__)
11
+ pinyin_to_symbol_map = {
12
+ line.split("\t")[0]: line.strip().split("\t")[1]
13
+ for line in open(os.path.join(current_file_path, "opencpop-strict.txt")).readlines()
14
+ }
15
+
16
+ import jieba.posseg as psg
17
+
18
+
19
+ rep_map = {
20
+ ":": ",",
21
+ ";": ",",
22
+ ",": ",",
23
+ "。": ".",
24
+ "!": "!",
25
+ "?": "?",
26
+ "\n": ".",
27
+ "·": ",",
28
+ "、": ",",
29
+ "...": "…",
30
+ "$": ".",
31
+ "“": "'",
32
+ "”": "'",
33
+ '"': "'",
34
+ "‘": "'",
35
+ "’": "'",
36
+ "(": "'",
37
+ ")": "'",
38
+ "(": "'",
39
+ ")": "'",
40
+ "《": "'",
41
+ "》": "'",
42
+ "【": "'",
43
+ "】": "'",
44
+ "[": "'",
45
+ "]": "'",
46
+ "—": "-",
47
+ "~": "-",
48
+ "~": "-",
49
+ "「": "'",
50
+ "」": "'",
51
+ }
52
+
53
+ tone_modifier = ToneSandhi()
54
+
55
+
56
+ def replace_punctuation(text):
57
+ text = text.replace("嗯", "恩").replace("呣", "母")
58
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
59
+
60
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
61
+
62
+ replaced_text = re.sub(
63
+ r"[^\u4e00-\u9fa5" + "".join(punctuation) + r"]+", "", replaced_text
64
+ )
65
+
66
+ return replaced_text
67
+
68
+
69
+ def g2p(text):
70
+ pattern = r"(?<=[{0}])\s*".format("".join(punctuation))
71
+ sentences = [i for i in re.split(pattern, text) if i.strip() != ""]
72
+ phones, tones, word2ph = _g2p(sentences)
73
+ assert sum(word2ph) == len(phones)
74
+ assert len(word2ph) == len(text)  # Sometimes this assert fails; you can add a try-except.
75
+ phones = ["_"] + phones + ["_"]
76
+ tones = [0] + tones + [0]
77
+ word2ph = [1] + word2ph + [1]
78
+ return phones, tones, word2ph
79
+
80
+
81
+ def _get_initials_finals(word):
82
+ initials = []
83
+ finals = []
84
+ orig_initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
85
+ orig_finals = lazy_pinyin(
86
+ word, neutral_tone_with_five=True, style=Style.FINALS_TONE3
87
+ )
88
+ for c, v in zip(orig_initials, orig_finals):
89
+ initials.append(c)
90
+ finals.append(v)
91
+ return initials, finals
92
+
93
+
94
+ def _g2p(segments):
95
+ phones_list = []
96
+ tones_list = []
97
+ word2ph = []
98
+ for seg in segments:
99
+ # Replace all English words in the sentence
100
+ seg = re.sub("[a-zA-Z]+", "", seg)
101
+ seg_cut = psg.lcut(seg)
102
+ initials = []
103
+ finals = []
104
+ seg_cut = tone_modifier.pre_merge_for_modify(seg_cut)
105
+ for word, pos in seg_cut:
106
+ if pos == "eng":
107
+ continue
108
+ sub_initials, sub_finals = _get_initials_finals(word)
109
+ sub_finals = tone_modifier.modified_tone(word, pos, sub_finals)
110
+ initials.append(sub_initials)
111
+ finals.append(sub_finals)
112
+
113
+ # assert len(sub_initials) == len(sub_finals) == len(word)
114
+ initials = sum(initials, [])
115
+ finals = sum(finals, [])
116
+ #
117
+ for c, v in zip(initials, finals):
118
+ raw_pinyin = c + v
119
+ # NOTE: post process for pypinyin outputs
120
+ # we discriminate i, ii and iii
121
+ if c == v:
122
+ assert c in punctuation
123
+ phone = [c]
124
+ tone = "0"
125
+ word2ph.append(1)
126
+ else:
127
+ v_without_tone = v[:-1]
128
+ tone = v[-1]
129
+
130
+ pinyin = c + v_without_tone
131
+ assert tone in "12345"
132
+
133
+ if c:
134
+ # 多音节
135
+ v_rep_map = {
136
+ "uei": "ui",
137
+ "iou": "iu",
138
+ "uen": "un",
139
+ }
140
+ if v_without_tone in v_rep_map.keys():
141
+ pinyin = c + v_rep_map[v_without_tone]
142
+ else:
143
+ # 单音节
144
+ pinyin_rep_map = {
145
+ "ing": "ying",
146
+ "i": "yi",
147
+ "in": "yin",
148
+ "u": "wu",
149
+ }
150
+ if pinyin in pinyin_rep_map.keys():
151
+ pinyin = pinyin_rep_map[pinyin]
152
+ else:
153
+ single_rep_map = {
154
+ "v": "yu",
155
+ "e": "e",
156
+ "i": "y",
157
+ "u": "w",
158
+ }
159
+ if pinyin[0] in single_rep_map.keys():
160
+ pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
161
+
162
+ assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin)
163
+ phone = pinyin_to_symbol_map[pinyin].split(" ")
164
+ word2ph.append(len(phone))
165
+
166
+ phones_list += phone
167
+ tones_list += [int(tone)] * len(phone)
168
+ return phones_list, tones_list, word2ph
169
+
170
+
171
+ def text_normalize(text):
172
+ numbers = re.findall(r"\d+(?:\.?\d+)?", text)
173
+ for number in numbers:
174
+ text = text.replace(number, cn2an.an2cn(number), 1)
175
+ text = replace_punctuation(text)
176
+ return text
177
+
178
+
179
+ def get_bert_feature(text, word2ph):
180
+ from text import chinese_bert
181
+
182
+ return chinese_bert.get_bert_feature(text, word2ph)
183
+
184
+
185
+ if __name__ == "__main__":
186
+ from text.chinese_bert import get_bert_feature
187
+
188
+ text = "啊!但是《原神》是由,米哈\游自主, [研发]的一款全.新开放世界.冒险游戏"
189
+ text = text_normalize(text)
190
+ print(text)
191
+ phones, tones, word2ph = g2p(text)
192
+ bert = get_bert_feature(text, word2ph)
193
+
194
+ print(phones, tones, word2ph, bert.shape)
195
+
196
+
197
+ # # 示例用法
198
+ # text = "这是一个示例文本:,你好!这是一个测试...."
199
+ # print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试
text/chinese_bert.py ADDED
@@ -0,0 +1,121 @@
1
+ import sys
2
+
3
+ import torch
4
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
5
+
6
+ from config import config
7
+
8
+ LOCAL_PATH = "./bert/chinese-roberta-wwm-ext-large"
9
+
10
+ tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
11
+
12
+ models = dict()
13
+
14
+
15
+ def get_bert_feature(
16
+ text,
17
+ word2ph,
18
+ device=config.bert_gen_config.device,
19
+ assist_text=None,
20
+ assist_text_weight=0.7,
21
+ ):
22
+ if (
23
+ sys.platform == "darwin"
24
+ and torch.backends.mps.is_available()
25
+ and device == "cpu"
26
+ ):
27
+ device = "mps"
28
+ if not device:
29
+ device = "cuda"
30
+ if device == "cuda" and not torch.cuda.is_available():
31
+ device = "cpu"
32
+ if device not in models.keys():
33
+ models[device] = AutoModelForMaskedLM.from_pretrained(LOCAL_PATH).to(device)
34
+ with torch.no_grad():
35
+ inputs = tokenizer(text, return_tensors="pt")
36
+ for i in inputs:
37
+ inputs[i] = inputs[i].to(device)
38
+ res = models[device](**inputs, output_hidden_states=True)
39
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
40
+ if assist_text:
41
+ style_inputs = tokenizer(assist_text, return_tensors="pt")
42
+ for i in style_inputs:
43
+ style_inputs[i] = style_inputs[i].to(device)
44
+ style_res = models[device](**style_inputs, output_hidden_states=True)
45
+ style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
46
+ style_res_mean = style_res.mean(0)
47
+ assert len(word2ph) == len(text) + 2
48
+ word2phone = word2ph
49
+ phone_level_feature = []
50
+ for i in range(len(word2phone)):
51
+ if assist_text:
52
+ repeat_feature = (
53
+ res[i].repeat(word2phone[i], 1) * (1 - assist_text_weight)
54
+ + style_res_mean.repeat(word2phone[i], 1) * assist_text_weight
55
+ )
56
+ else:
57
+ repeat_feature = res[i].repeat(word2phone[i], 1)
58
+ phone_level_feature.append(repeat_feature)
59
+
60
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
61
+
62
+ return phone_level_feature.T
63
+
64
+
65
+ if __name__ == "__main__":
66
+ word_level_feature = torch.rand(38, 1024) # 12个词,每个词1024维特征
67
+ word2phone = [
68
+ 1,
69
+ 2,
70
+ 1,
71
+ 2,
72
+ 2,
73
+ 1,
74
+ 2,
75
+ 2,
76
+ 1,
77
+ 2,
78
+ 2,
79
+ 1,
80
+ 2,
81
+ 2,
82
+ 2,
83
+ 2,
84
+ 2,
85
+ 1,
86
+ 1,
87
+ 2,
88
+ 2,
89
+ 1,
90
+ 2,
91
+ 2,
92
+ 2,
93
+ 2,
94
+ 1,
95
+ 2,
96
+ 2,
97
+ 2,
98
+ 2,
99
+ 2,
100
+ 1,
101
+ 2,
102
+ 2,
103
+ 2,
104
+ 2,
105
+ 1,
106
+ ]
107
+
108
+ # 计算总帧数
109
+ total_frames = sum(word2phone)
110
+ print(word_level_feature.shape)
111
+ print(word2phone)
112
+ phone_level_feature = []
113
+ for i in range(len(word2phone)):
114
+ print(word_level_feature[i].shape)
115
+
116
+ # 对每个词重复word2phone[i]次
117
+ repeat_feature = word_level_feature[i].repeat(word2phone[i], 1)
118
+ phone_level_feature.append(repeat_feature)
119
+
120
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
121
+ print(phone_level_feature.shape) # torch.Size([36, 1024])
text/cleaner.py ADDED
@@ -0,0 +1,31 @@
1
+ from text import chinese, japanese, english, cleaned_text_to_sequence
2
+
3
+
4
+ language_module_map = {"ZH": chinese, "JP": japanese, "EN": english}
5
+
6
+
7
+ def clean_text(text, language, use_jp_extra=True):
8
+ language_module = language_module_map[language]
9
+ norm_text = language_module.text_normalize(text)
10
+ if language == "JP":
11
+ phones, tones, word2ph = language_module.g2p(norm_text, use_jp_extra)
12
+ else:
13
+ phones, tones, word2ph = language_module.g2p(norm_text)
14
+ return norm_text, phones, tones, word2ph
15
+
16
+
17
+ def clean_text_bert(text, language):
18
+ language_module = language_module_map[language]
19
+ norm_text = language_module.text_normalize(text)
20
+ phones, tones, word2ph = language_module.g2p(norm_text)
21
+ bert = language_module.get_bert_feature(norm_text, word2ph)
22
+ return phones, tones, bert
23
+
24
+
25
+ def text_to_sequence(text, language):
26
+ norm_text, phones, tones, word2ph = clean_text(text, language)
27
+ return cleaned_text_to_sequence(phones, tones, language)
28
+
29
+
30
+ if __name__ == "__main__":
31
+ pass
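A minimal usage sketch of clean_text(), assuming the local ./bert models and dictionaries that the language modules load at import time are already in place:

from text.cleaner import clean_text

norm_text, phones, tones, word2ph = clean_text("こんにちは、世界!", "JP", use_jp_extra=True)
# word2ph assigns a phone count to each character of norm_text (plus the leading
# and trailing "_" symbols), so its entries always sum to the number of phones.
assert sum(word2ph) == len(phones)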
text/cmudict.rep ADDED
The diff for this file is too large to render. See raw diff
 
text/cmudict_cache.pickle ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9b21b20325471934ba92f2e4a5976989e7d920caa32e7a286eacb027d197949
3
+ size 6212655
text/english.py ADDED
@@ -0,0 +1,495 @@
1
+ import pickle
2
+ import os
3
+ import re
4
+ from g2p_en import G2p
5
+ from transformers import DebertaV2Tokenizer
6
+
7
+ from text import symbols
8
+ from text.symbols import punctuation
9
+
10
+ current_file_path = os.path.dirname(__file__)
11
+ CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
12
+ CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle")
13
+ _g2p = G2p()
14
+ LOCAL_PATH = "./bert/deberta-v3-large"
15
+ tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH)
16
+
17
+ arpa = {
18
+ "AH0",
19
+ "S",
20
+ "AH1",
21
+ "EY2",
22
+ "AE2",
23
+ "EH0",
24
+ "OW2",
25
+ "UH0",
26
+ "NG",
27
+ "B",
28
+ "G",
29
+ "AY0",
30
+ "M",
31
+ "AA0",
32
+ "F",
33
+ "AO0",
34
+ "ER2",
35
+ "UH1",
36
+ "IY1",
37
+ "AH2",
38
+ "DH",
39
+ "IY0",
40
+ "EY1",
41
+ "IH0",
42
+ "K",
43
+ "N",
44
+ "W",
45
+ "IY2",
46
+ "T",
47
+ "AA1",
48
+ "ER1",
49
+ "EH2",
50
+ "OY0",
51
+ "UH2",
52
+ "UW1",
53
+ "Z",
54
+ "AW2",
55
+ "AW1",
56
+ "V",
57
+ "UW2",
58
+ "AA2",
59
+ "ER",
60
+ "AW0",
61
+ "UW0",
62
+ "R",
63
+ "OW1",
64
+ "EH1",
65
+ "ZH",
66
+ "AE0",
67
+ "IH2",
68
+ "IH",
69
+ "Y",
70
+ "JH",
71
+ "P",
72
+ "AY1",
73
+ "EY0",
74
+ "OY2",
75
+ "TH",
76
+ "HH",
77
+ "D",
78
+ "ER0",
79
+ "CH",
80
+ "AO1",
81
+ "AE1",
82
+ "AO2",
83
+ "OY1",
84
+ "AY2",
85
+ "IH1",
86
+ "OW0",
87
+ "L",
88
+ "SH",
89
+ }
90
+
91
+
92
+ def post_replace_ph(ph):
93
+ rep_map = {
94
+ ":": ",",
95
+ ";": ",",
96
+ ",": ",",
97
+ "。": ".",
98
+ "!": "!",
99
+ "?": "?",
100
+ "\n": ".",
101
+ "·": ",",
102
+ "、": ",",
103
+ "…": "...",
104
+ "···": "...",
105
+ "・・・": "...",
106
+ "v": "V",
107
+ }
108
+ if ph in rep_map.keys():
109
+ ph = rep_map[ph]
110
+ if ph in symbols:
111
+ return ph
112
+ if ph not in symbols:
113
+ ph = "UNK"
114
+ return ph
115
+
116
+
117
+ rep_map = {
118
+ ":": ",",
119
+ ";": ",",
120
+ ",": ",",
121
+ "。": ".",
122
+ "!": "!",
123
+ "?": "?",
124
+ "\n": ".",
125
+ ".": ".",
126
+ "…": "...",
127
+ "···": "...",
128
+ "・・・": "...",
129
+ "·": ",",
130
+ "・": ",",
131
+ "、": ",",
132
+ "$": ".",
133
+ "“": "'",
134
+ "”": "'",
135
+ '"': "'",
136
+ "‘": "'",
137
+ "’": "'",
138
+ "(": "'",
139
+ ")": "'",
140
+ "(": "'",
141
+ ")": "'",
142
+ "《": "'",
143
+ "》": "'",
144
+ "【": "'",
145
+ "】": "'",
146
+ "[": "'",
147
+ "]": "'",
148
+ "—": "-",
149
+ "−": "-",
150
+ "~": "-",
151
+ "~": "-",
152
+ "「": "'",
153
+ "」": "'",
154
+ }
155
+
156
+
157
+ def replace_punctuation(text):
158
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
159
+
160
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
161
+
162
+ # replaced_text = re.sub(
163
+ # r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
164
+ # + "".join(punctuation)
165
+ # + r"]+",
166
+ # "",
167
+ # replaced_text,
168
+ # )
169
+
170
+ return replaced_text
171
+
172
+
173
+ def read_dict():
174
+ g2p_dict = {}
175
+ start_line = 49
176
+ with open(CMU_DICT_PATH) as f:
177
+ line = f.readline()
178
+ line_index = 1
179
+ while line:
180
+ if line_index >= start_line:
181
+ line = line.strip()
182
+ word_split = line.split(" ")
183
+ word = word_split[0]
184
+
185
+ syllable_split = word_split[1].split(" - ")
186
+ g2p_dict[word] = []
187
+ for syllable in syllable_split:
188
+ phone_split = syllable.split(" ")
189
+ g2p_dict[word].append(phone_split)
190
+
191
+ line_index = line_index + 1
192
+ line = f.readline()
193
+
194
+ return g2p_dict
195
+
196
+
197
+ def cache_dict(g2p_dict, file_path):
198
+ with open(file_path, "wb") as pickle_file:
199
+ pickle.dump(g2p_dict, pickle_file)
200
+
201
+
202
+ def get_dict():
203
+ if os.path.exists(CACHE_PATH):
204
+ with open(CACHE_PATH, "rb") as pickle_file:
205
+ g2p_dict = pickle.load(pickle_file)
206
+ else:
207
+ g2p_dict = read_dict()
208
+ cache_dict(g2p_dict, CACHE_PATH)
209
+
210
+ return g2p_dict
211
+
212
+
213
+ eng_dict = get_dict()
214
+
215
+
216
+ def refine_ph(phn):
217
+ tone = 0
218
+ if re.search(r"\d$", phn):
219
+ tone = int(phn[-1]) + 1
220
+ phn = phn[:-1]
221
+ else:
222
+ tone = 3
223
+ return phn.lower(), tone
224
+
225
+
226
+ def refine_syllables(syllables):
227
+ tones = []
228
+ phonemes = []
229
+ for phn_list in syllables:
230
+ for i in range(len(phn_list)):
231
+ phn = phn_list[i]
232
+ phn, tone = refine_ph(phn)
233
+ phonemes.append(phn)
234
+ tones.append(tone)
235
+ return phonemes, tones
236
+
237
+
238
+ import re
239
+ import inflect
240
+
241
+ _inflect = inflect.engine()
242
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
243
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
244
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
245
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
246
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
247
+ _number_re = re.compile(r"[0-9]+")
248
+
249
+ # List of (regular expression, replacement) pairs for abbreviations:
250
+ _abbreviations = [
251
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
252
+ for x in [
253
+ ("mrs", "misess"),
254
+ ("mr", "mister"),
255
+ ("dr", "doctor"),
256
+ ("st", "saint"),
257
+ ("co", "company"),
258
+ ("jr", "junior"),
259
+ ("maj", "major"),
260
+ ("gen", "general"),
261
+ ("drs", "doctors"),
262
+ ("rev", "reverend"),
263
+ ("lt", "lieutenant"),
264
+ ("hon", "honorable"),
265
+ ("sgt", "sergeant"),
266
+ ("capt", "captain"),
267
+ ("esq", "esquire"),
268
+ ("ltd", "limited"),
269
+ ("col", "colonel"),
270
+ ("ft", "fort"),
271
+ ]
272
+ ]
273
+
274
+
275
+ # List of (ipa, lazy ipa) pairs:
276
+ _lazy_ipa = [
277
+ (re.compile("%s" % x[0]), x[1])
278
+ for x in [
279
+ ("r", "ɹ"),
280
+ ("æ", "e"),
281
+ ("ɑ", "a"),
282
+ ("ɔ", "o"),
283
+ ("ð", "z"),
284
+ ("θ", "s"),
285
+ ("ɛ", "e"),
286
+ ("ɪ", "i"),
287
+ ("ʊ", "u"),
288
+ ("ʒ", "ʥ"),
289
+ ("ʤ", "ʥ"),
290
+ ("ˈ", "↓"),
291
+ ]
292
+ ]
293
+
294
+ # List of (ipa, lazy ipa2) pairs:
295
+ _lazy_ipa2 = [
296
+ (re.compile("%s" % x[0]), x[1])
297
+ for x in [
298
+ ("r", "ɹ"),
299
+ ("ð", "z"),
300
+ ("θ", "s"),
301
+ ("ʒ", "ʑ"),
302
+ ("ʤ", "dʑ"),
303
+ ("ˈ", "↓"),
304
+ ]
305
+ ]
306
+
307
+ # List of (ipa, ipa2) pairs
308
+ _ipa_to_ipa2 = [
309
+ (re.compile("%s" % x[0]), x[1]) for x in [("r", "ɹ"), ("ʤ", "dʒ"), ("ʧ", "tʃ")]
310
+ ]
311
+
312
+
313
+ def _expand_dollars(m):
314
+ match = m.group(1)
315
+ parts = match.split(".")
316
+ if len(parts) > 2:
317
+ return match + " dollars" # Unexpected format
318
+ dollars = int(parts[0]) if parts[0] else 0
319
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
320
+ if dollars and cents:
321
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
322
+ cent_unit = "cent" if cents == 1 else "cents"
323
+ return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
324
+ elif dollars:
325
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
326
+ return "%s %s" % (dollars, dollar_unit)
327
+ elif cents:
328
+ cent_unit = "cent" if cents == 1 else "cents"
329
+ return "%s %s" % (cents, cent_unit)
330
+ else:
331
+ return "zero dollars"
332
+
333
+
334
+ def _remove_commas(m):
335
+ return m.group(1).replace(",", "")
336
+
337
+
338
+ def _expand_ordinal(m):
339
+ return _inflect.number_to_words(m.group(0))
340
+
341
+
342
+ def _expand_number(m):
343
+ num = int(m.group(0))
344
+ if num > 1000 and num < 3000:
345
+ if num == 2000:
346
+ return "two thousand"
347
+ elif num > 2000 and num < 2010:
348
+ return "two thousand " + _inflect.number_to_words(num % 100)
349
+ elif num % 100 == 0:
350
+ return _inflect.number_to_words(num // 100) + " hundred"
351
+ else:
352
+ return _inflect.number_to_words(
353
+ num, andword="", zero="oh", group=2
354
+ ).replace(", ", " ")
355
+ else:
356
+ return _inflect.number_to_words(num, andword="")
357
+
358
+
359
+ def _expand_decimal_point(m):
360
+ return m.group(1).replace(".", " point ")
361
+
362
+
363
+ def normalize_numbers(text):
364
+ text = re.sub(_comma_number_re, _remove_commas, text)
365
+ text = re.sub(_pounds_re, r"\1 pounds", text)
366
+ text = re.sub(_dollars_re, _expand_dollars, text)
367
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
368
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
369
+ text = re.sub(_number_re, _expand_number, text)
370
+ return text
371
+
372
+
373
+ def text_normalize(text):
374
+ text = normalize_numbers(text)
375
+ text = replace_punctuation(text)
376
+ text = re.sub(r"([,;.\?\!])([\w])", r"\1 \2", text)
377
+ return text
378
+
379
+
380
+ def distribute_phone(n_phone, n_word):
381
+ phones_per_word = [0] * n_word
382
+ for task in range(n_phone):
383
+ min_tasks = min(phones_per_word)
384
+ min_index = phones_per_word.index(min_tasks)
385
+ phones_per_word[min_index] += 1
386
+ return phones_per_word
387
+
388
+
389
+ def sep_text(text):
390
+ words = re.split(r"([,;.\?\!\s+])", text)
391
+ words = [word for word in words if word.strip() != ""]
392
+ return words
393
+
394
+
395
+ def text_to_words(text):
396
+ tokens = tokenizer.tokenize(text)
397
+ words = []
398
+ for idx, t in enumerate(tokens):
399
+ if t.startswith("▁"):
400
+ words.append([t[1:]])
401
+ else:
402
+ if t in punctuation:
403
+ if idx == len(tokens) - 1:
404
+ words.append([f"{t}"])
405
+ else:
406
+ if (
407
+ not tokens[idx + 1].startswith("▁")
408
+ and tokens[idx + 1] not in punctuation
409
+ ):
410
+ if idx == 0:
411
+ words.append([])
412
+ words[-1].append(f"{t}")
413
+ else:
414
+ words.append([f"{t}"])
415
+ else:
416
+ if idx == 0:
417
+ words.append([])
418
+ words[-1].append(f"{t}")
419
+ return words
420
+
421
+
422
+ def g2p(text):
423
+ phones = []
424
+ tones = []
425
+ phone_len = []
426
+ # words = sep_text(text)
427
+ # tokens = [tokenizer.tokenize(i) for i in words]
428
+ words = text_to_words(text)
429
+
430
+ for word in words:
431
+ temp_phones, temp_tones = [], []
432
+ if len(word) > 1:
433
+ if "'" in word:
434
+ word = ["".join(word)]
435
+ for w in word:
436
+ if w in punctuation:
437
+ temp_phones.append(w)
438
+ temp_tones.append(0)
439
+ continue
440
+ if w.upper() in eng_dict:
441
+ phns, tns = refine_syllables(eng_dict[w.upper()])
442
+ temp_phones += [post_replace_ph(i) for i in phns]
443
+ temp_tones += tns
444
+ # w2ph.append(len(phns))
445
+ else:
446
+ phone_list = list(filter(lambda p: p != " ", _g2p(w)))
447
+ phns = []
448
+ tns = []
449
+ for ph in phone_list:
450
+ if ph in arpa:
451
+ ph, tn = refine_ph(ph)
452
+ phns.append(ph)
453
+ tns.append(tn)
454
+ else:
455
+ phns.append(ph)
456
+ tns.append(0)
457
+ temp_phones += [post_replace_ph(i) for i in phns]
458
+ temp_tones += tns
459
+ phones += temp_phones
460
+ tones += temp_tones
461
+ phone_len.append(len(temp_phones))
462
+ # phones = [post_replace_ph(i) for i in phones]
463
+
464
+ word2ph = []
465
+ for token, pl in zip(words, phone_len):
466
+ word_len = len(token)
467
+
468
+ aaa = distribute_phone(pl, word_len)
469
+ word2ph += aaa
470
+
471
+ phones = ["_"] + phones + ["_"]
472
+ tones = [0] + tones + [0]
473
+ word2ph = [1] + word2ph + [1]
474
+ assert len(phones) == len(tones), text
475
+ assert len(phones) == sum(word2ph), text
476
+
477
+ return phones, tones, word2ph
478
+
479
+
480
+ def get_bert_feature(text, word2ph):
481
+ from text import english_bert_mock
482
+
483
+ return english_bert_mock.get_bert_feature(text, word2ph)
484
+
485
+
486
+ if __name__ == "__main__":
487
+ # print(get_dict())
488
+ # print(eng_word_to_phoneme("hello"))
489
+ print(g2p("In this paper, we propose 1 DSPGAN, a GAN-based universal vocoder."))
490
+ # all_phones = set()
491
+ # for k, syllables in eng_dict.items():
492
+ # for group in syllables:
493
+ # for ph in group:
494
+ # all_phones.add(ph)
495
+ # print(all_phones)
text/english_bert_mock.py ADDED
@@ -0,0 +1,63 @@
1
+ import sys
2
+
3
+ import torch
4
+ from transformers import DebertaV2Model, DebertaV2Tokenizer
5
+
6
+ from config import config
7
+
8
+
9
+ LOCAL_PATH = "./bert/deberta-v3-large"
10
+
11
+ tokenizer = DebertaV2Tokenizer.from_pretrained(LOCAL_PATH)
12
+
13
+ models = dict()
14
+
15
+
16
+ def get_bert_feature(
17
+ text,
18
+ word2ph,
19
+ device=config.bert_gen_config.device,
20
+ assist_text=None,
21
+ assist_text_weight=0.7,
22
+ ):
23
+ if (
24
+ sys.platform == "darwin"
25
+ and torch.backends.mps.is_available()
26
+ and device == "cpu"
27
+ ):
28
+ device = "mps"
29
+ if not device:
30
+ device = "cuda"
31
+ if device == "cuda" and not torch.cuda.is_available():
32
+ device = "cpu"
33
+ if device not in models.keys():
34
+ models[device] = DebertaV2Model.from_pretrained(LOCAL_PATH).to(device)
35
+ with torch.no_grad():
36
+ inputs = tokenizer(text, return_tensors="pt")
37
+ for i in inputs:
38
+ inputs[i] = inputs[i].to(device)
39
+ res = models[device](**inputs, output_hidden_states=True)
40
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
41
+ if assist_text:
42
+ style_inputs = tokenizer(assist_text, return_tensors="pt")
43
+ for i in style_inputs:
44
+ style_inputs[i] = style_inputs[i].to(device)
45
+ style_res = models[device](**style_inputs, output_hidden_states=True)
46
+ style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
47
+ style_res_mean = style_res.mean(0)
48
+ assert len(word2ph) == res.shape[0], (text, res.shape[0], len(word2ph))
49
+ word2phone = word2ph
50
+ phone_level_feature = []
51
+ for i in range(len(word2phone)):
52
+ if assist_text:
53
+ repeat_feature = (
54
+ res[i].repeat(word2phone[i], 1) * (1 - assist_text_weight)
55
+ + style_res_mean.repeat(word2phone[i], 1) * assist_text_weight
56
+ )
57
+ else:
58
+ repeat_feature = res[i].repeat(word2phone[i], 1)
59
+ phone_level_feature.append(repeat_feature)
60
+
61
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
62
+
63
+ return phone_level_feature.T
text/japanese.py ADDED
@@ -0,0 +1,585 @@
1
+ # Convert Japanese text to phonemes which are
2
+ # compatible with Julius https://github.com/julius-speech/segmentation-kit
3
+ import re
4
+ import unicodedata
5
+
6
+ import pyopenjtalk
7
+ from num2words import num2words
8
+ from transformers import AutoTokenizer
9
+
10
+ from common.log import logger
11
+ from text import punctuation
12
+ from text.japanese_mora_list import (
13
+ mora_kata_to_mora_phonemes,
14
+ mora_phonemes_to_mora_kata,
15
+ )
16
+
17
+ # 子音の集合
18
+ COSONANTS = set(
19
+ [
20
+ cosonant
21
+ for cosonant, _ in mora_kata_to_mora_phonemes.values()
22
+ if cosonant is not None
23
+ ]
24
+ )
25
+
26
+ # 母音の集合、便宜上「ん」を含める
27
+ VOWELS = {"a", "i", "u", "e", "o", "N"}
28
+
29
+
30
+ # 正規化で記号を変換するための辞書
31
+ rep_map = {
32
+ ":": ",",
33
+ ";": ",",
34
+ ",": ",",
35
+ "。": ".",
36
+ "!": "!",
37
+ "?": "?",
38
+ "\n": ".",
39
+ ".": ".",
40
+ "…": "...",
41
+ "···": "...",
42
+ "・・・": "...",
43
+ "·": ",",
44
+ "・": ",",
45
+ "、": ",",
46
+ "$": ".",
47
+ "“": "'",
48
+ "”": "'",
49
+ '"': "'",
50
+ "‘": "'",
51
+ "’": "'",
52
+ "(": "'",
53
+ ")": "'",
54
+ "(": "'",
55
+ ")": "'",
56
+ "《": "'",
57
+ "》": "'",
58
+ "【": "'",
59
+ "】": "'",
60
+ "[": "'",
61
+ "]": "'",
62
+ "—": "-",
63
+ "−": "-",
64
+ # "~": "-", # これは長音記号「ー」として扱うよう変更
65
+ # "~": "-", # これも長音記号「ー」として扱うよう変更
66
+ "「": "'",
67
+ "」": "'",
68
+ }
69
+
70
+
71
+ def text_normalize(text):
72
+ """
73
+ 日本語のテキストを正規化する。
74
+ 結果は、ちょうど次の文字のみからなる:
75
+ - ひらがな
76
+ - カタカナ(全角長音記号「ー」が入る!)
77
+ - 漢字
78
+ - 半角アルファベット(大文字と小文字)
79
+ - ギリシャ文字
80
+ - `.` (句点`。`や`…`の一部や改行等)
81
+ - `,` (読点`、`や`:`等)
82
+ - `?` (疑問符`?`)
83
+ - `!` (感嘆符`!`)
84
+ - `'` (`「`や`」`等)
85
+ - `-` (`―`(ダッシュ、長音記号ではない)や`-`等)
86
+
87
+ 注意点:
88
+ - 三点リーダー`…`は`...`に変換される(`なるほど…。` → `なるほど....`)
89
+ - 数字は漢字に変換される(`1,100円` → `千百円`、`52.34` → `五十二点三四`)
90
+ - 読点や疑問符等の位置・個数等は保持される(`??あ、、!!!` → `??あ,,!!!`)
91
+ """
92
+ res = unicodedata.normalize("NFKC", text) # ここでアルファベットは半角になる
93
+ res = japanese_convert_numbers_to_words(res) # 「100円」→「百円」等
94
+ # 「~」と「~」も長音記号として扱う
95
+ res = res.replace("~", "ー")
96
+ res = res.replace("~", "ー")
97
+
98
+ res = replace_punctuation(res) # 句読点等正規化、読めない文字を削除
99
+
100
+ # 結合文字の濁点・半濁点を削除
101
+ # 通常の「ば」等はそのままのこされる、「あ゛」は上で「あ゙」になりここで「あ」になる
102
+ res = res.replace("\u3099", "") # 結合文字の濁点を削除、る゙ → る
103
+ res = res.replace("\u309A", "") # 結合文字の半濁点を削除、な゚ → な
104
+ return res
105
+
106
+
107
+ def replace_punctuation(text: str) -> str:
108
+ """句読点等を「.」「,」「!」「?」「'」「-」に正規化し、OpenJTalkで読みが取得できるもののみ残す:
109
+ 漢字・平仮名・カタカナ、アルファベット、ギリシャ文字
110
+ """
111
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
112
+
113
+ # 句読点を辞書で置換
114
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
115
+
116
+ replaced_text = re.sub(
117
+ # ↓ ひらがな、カタカナ、漢字
118
+ r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
119
+ # ↓ 半角アルファベット(大文字と小文字)
120
+ + r"\u0041-\u005A\u0061-\u007A"
121
+ # ↓ 全角アルファベット(大文字と小文字)
122
+ + r"\uFF21-\uFF3A\uFF41-\uFF5A"
123
+ # ↓ ギリシャ文字
124
+ + r"\u0370-\u03FF\u1F00-\u1FFF"
125
+ # ↓ "!", "?", "…", ",", ".", "'", "-", 但し`…`はすでに`...`に変換されている
126
+ + "".join(punctuation) + r"]+",
127
+ # 上述以外の文字を削除
128
+ "",
129
+ replaced_text,
130
+ )
131
+
132
+ return replaced_text
133
+
134
+
135
+ _NUMBER_WITH_SEPARATOR_RX = re.compile("[0-9]{1,3}(,[0-9]{3})+")
136
+ _CURRENCY_MAP = {"$": "ドル", "¥": "円", "£": "ポンド", "€": "ユーロ"}
137
+ _CURRENCY_RX = re.compile(r"([$¥£€])([0-9.]*[0-9])")
138
+ _NUMBER_RX = re.compile(r"[0-9]+(\.[0-9]+)?")
139
+
140
+
141
+ def japanese_convert_numbers_to_words(text: str) -> str:
142
+ res = _NUMBER_WITH_SEPARATOR_RX.sub(lambda m: m[0].replace(",", ""), text)
143
+ res = _CURRENCY_RX.sub(lambda m: m[2] + _CURRENCY_MAP.get(m[1], m[1]), res)
144
+ res = _NUMBER_RX.sub(lambda m: num2words(m[0], lang="ja"), res)
145
+ return res
146
+
147
+
148
+ def g2p(
149
+ norm_text: str, use_jp_extra: bool = True
150
+ ) -> tuple[list[str], list[int], list[int]]:
151
+ """
152
+ 他で使われるメインの関数。`text_normalize()`で正規化された`norm_text`を受け取り、
153
+ - phones: 音素のリスト(ただし`!`や`,`や`.`等punctuationが含まれうる)
154
+ - tones: アクセントのリスト、0(低)と1(高)からなり、phonesと同じ長さ
155
+ - word2ph: 元のテキストの各文字に音素が何個割り当てられるかを表すリスト
156
+ のタプルを返す。
157
+ ただし`phones`と`tones`の最初と終わりに`_`が入り、応じて`word2ph`の最初と最後に1が追加される。
158
+ use_jp_extra: Falseの場合、「ん」の音素を「N」ではなく「n」とする。
159
+ """
160
+ # pyopenjtalkのフルコンテキストラベルを使ってアクセントを取り出すと、punctuationの位置が消えてしまい情報が失われてしまう:
161
+ # 「こんにちは、世界。」と「こんにちは!世界。」と「こんにちは!!!???世界……。」は全て同じになる。
162
+ # よって、まずpunctuation無しの音素とアクセントのリストを作り、
163
+ # それとは別にpyopenjtalk.run_frontend()で得られる音素リスト(こちらはpunctuationが保持される)を使い、
164
+ # アクセント割当をしなおすことによってpunctuationを含めた音素とアクセントのリストを作る。
165
+
166
+ # punctuationがすべて消えた、音素とアクセントのタプルのリスト(「ん」は「N」)
167
+ phone_tone_list_wo_punct = g2phone_tone_wo_punct(norm_text)
168
+
169
+ # sep_text: 単語単位の単語のリスト
170
+ # sep_kata: 単語単位の単語のカタカナ読みのリスト
171
+ sep_text, sep_kata = text2sep_kata(norm_text)
172
+
173
+ # sep_phonemes: 各単語ごとの音素のリストのリスト
174
+ sep_phonemes = handle_long([kata2phoneme_list(i) for i in sep_kata])
175
+
176
+ # phone_w_punct: sep_phonemesを結合した、punctuationを元のまま保持した音素列
177
+ phone_w_punct: list[str] = []
178
+ for i in sep_phonemes:
179
+ phone_w_punct += i
180
+
181
+ # punctuation無しのアクセント情報を使って、punctuationを含めたアクセント情報を作る
182
+ phone_tone_list = align_tones(phone_w_punct, phone_tone_list_wo_punct)
183
+ # logger.debug(f"phone_tone_list:\n{phone_tone_list}")
184
+ # word2phは厳密な解答は不可能なので(「今日」「眼鏡」等の熟字訓が存在)、
185
+ # Bert-VITS2では、単語単位の分割を使って、単語の文字ごとにだいたい均等に音素を分配する
186
+
187
+ # sep_textから、各単語を1文字1文字分割して、文字のリスト(のリスト)を作る
188
+ sep_tokenized: list[list[str]] = []
189
+ for i in sep_text:
190
+ if i not in punctuation:
191
+ sep_tokenized.append(
192
+ tokenizer.tokenize(i)
193
+ ) # ここでおそらく`i`が文字単位に分割される
194
+ else:
195
+ sep_tokenized.append([i])
196
+
197
+ # 各単語について、音素の数と文字の数を比較して、均等っぽく分配する
198
+ word2ph = []
199
+ for token, phoneme in zip(sep_tokenized, sep_phonemes):
200
+ phone_len = len(phoneme)
201
+ word_len = len(token)
202
+ word2ph += distribute_phone(phone_len, word_len)
203
+
204
+ # 最初と最後に`_`記号を追加、アクセントは0(低)、word2phもそれに合わせて追加
205
+ phone_tone_list = [("_", 0)] + phone_tone_list + [("_", 0)]
206
+ word2ph = [1] + word2ph + [1]
207
+
208
+ phones = [phone for phone, _ in phone_tone_list]
209
+ tones = [tone for _, tone in phone_tone_list]
210
+
211
+ assert len(phones) == sum(word2ph), f"{len(phones)} != {sum(word2ph)}"
212
+
213
+ # use_jp_extraでない場合は「N」を「n」に変換
214
+ if not use_jp_extra:
215
+ phones = [phone if phone != "N" else "n" for phone in phones]
216
+
217
+ return phones, tones, word2ph
218
+
219
+
220
+ def g2kata_tone(norm_text: str) -> list[tuple[str, int]]:
221
+ phones, tones, _ = g2p(norm_text, use_jp_extra=True)
222
+ return phone_tone2kata_tone(list(zip(phones, tones)))
223
+
224
+
225
+ def phone_tone2kata_tone(phone_tone: list[tuple[str, int]]) -> list[tuple[str, int]]:
226
+ """phone_toneをのphone部分をカタカナに変換する。ただし最初と最後の("_", 0)は無視"""
227
+ phone_tone = phone_tone[1:] # 最初の("_", 0)を無視
228
+ phones = [phone for phone, _ in phone_tone]
229
+ tones = [tone for _, tone in phone_tone]
230
+ result: list[tuple[str, int]] = []
231
+ current_mora = ""
232
+ for phone, next_phone, tone, next_tone in zip(phones, phones[1:], tones, tones[1:]):
233
+ # zipの関係で最後の("_", 0)は無視されている
234
+ if phone in punctuation:
235
+ result.append((phone, tone))
236
+ continue
237
+ if phone in COSONANTS: # n以外の子音の場合
238
+ assert current_mora == "", f"Unexpected {phone} after {current_mora}"
239
+ assert tone == next_tone, f"Unexpected {phone} tone {tone} != {next_tone}"
240
+ current_mora = phone
241
+ else:
242
+ # phoneが母音もしくは「N」
243
+ current_mora += phone
244
+ result.append((mora_phonemes_to_mora_kata[current_mora], tone))
245
+ current_mora = ""
246
+ return result
247
+
248
+
249
+ def kata_tone2phone_tone(kata_tone: list[tuple[str, int]]) -> list[tuple[str, int]]:
250
+ """`phone_tone2kata_tone()`の逆。"""
251
+ result: list[tuple[str, int]] = [("_", 0)]
252
+ for mora, tone in kata_tone:
253
+ if mora in punctuation:
254
+ result.append((mora, tone))
255
+ else:
256
+ cosonant, vowel = mora_kata_to_mora_phonemes[mora]
257
+ if cosonant is None:
258
+ result.append((vowel, tone))
259
+ else:
260
+ result.append((cosonant, tone))
261
+ result.append((vowel, tone))
262
+ result.append(("_", 0))
263
+ return result
264
+
265
+
266
+ def g2phone_tone_wo_punct(text: str) -> list[tuple[str, int]]:
267
+ """
268
+ テキストに対して、音素とアクセント(0か1)のペアのリストを返す。
269
+ ただし「!」「.」「?」等の非音素記号(punctuation)は全て消える(ポーズ記号も残さない)。
270
+ 非音素記号を含める処理は`align_tones()`で行われる。
271
+ また「っ」は「q」に、「ん」は「N」に変換される。
272
+ 例: "こんにちは、世界ー。。元気?!" →
273
+ [('k', 0), ('o', 0), ('N', 1), ('n', 1), ('i', 1), ('ch', 1), ('i', 1), ('w', 1), ('a', 1), ('s', 1), ('e', 1), ('k', 0), ('a', 0), ('i', 0), ('i', 0), ('g', 1), ('e', 1), ('N', 0), ('k', 0), ('i', 0)]
274
+ """
275
+ prosodies = pyopenjtalk_g2p_prosody(text, drop_unvoiced_vowels=True)
276
+ # logger.debug(f"prosodies: {prosodies}")
277
+ result: list[tuple[str, int]] = []
278
+ current_phrase: list[tuple[str, int]] = []
279
+ current_tone = 0
280
+ for i, letter in enumerate(prosodies):
281
+ # 特殊記号の処理
282
+
283
+ # 文頭記号、無視する
284
+ if letter == "^":
285
+ assert i == 0, "Unexpected ^"
286
+ # アクセント句の終わりに来る記号
287
+ elif letter in ("$", "?", "_", "#"):
288
+ # 保持しているフレーズを、アクセント数値を0-1に修正し結果に追加
289
+ result.extend(fix_phone_tone(current_phrase))
290
+ # 末尾に来る終了記号、無視(文中の疑問文は`_`になる)
291
+ if letter in ("$", "?"):
292
+ assert i == len(prosodies) - 1, f"Unexpected {letter}"
293
+ # あとは"_"(ポーズ)と"#"(アクセント句の境界)のみ
294
+ # これらは残さず、次のアクセント句に備える。
295
+ current_phrase = []
296
+ # 0を基準点にしてそこから上昇・下降する(負の場合は上の`fix_phone_tone`で直る)
297
+ current_tone = 0
298
+ # アクセント上昇記号
299
+ elif letter == "[":
300
+ current_tone = current_tone + 1
301
+ # アクセント下降記号
302
+ elif letter == "]":
303
+ current_tone = current_tone - 1
304
+ # それ以外は通常の音素
305
+ else:
306
+ if letter == "cl": # 「っ」の処理
307
+ letter = "q"
308
+ # elif letter == "N": # 「ん」の処理
309
+ # letter = "n"
310
+ current_phrase.append((letter, current_tone))
311
+ return result
312
+
313
+
314
+ def text2sep_kata(norm_text: str) -> tuple[list[str], list[str]]:
315
+ """
316
+ `text_normalize`で正規化済みの`norm_text`を受け取り、それを単語分割し、
317
+ 分割された単語リストとその読み(カタカナor記号1文字)のリストのタプルを返す。
318
+ 単語分割結果は、`g2p()`の`word2ph`で1文字あたりに割り振る音素記号の数を決めるために使う。
319
+ 例:
320
+ `私はそう思う!って感じ?` →
321
+ ["私", "は", "そう", "思う", "!", "って", "感じ", "?"], ["ワタシ", "ワ", "ソー", "オモウ", "!", "ッテ", "カンジ", "?"]
322
+ """
323
+ # parsed: OpenJTalkの解析結果
324
+ parsed = pyopenjtalk.run_frontend(norm_text)
325
+ sep_text: list[str] = []
326
+ sep_kata: list[str] = []
327
+ for parts in parsed:
328
+ # word: 実際の単語の文字列
329
+ # yomi: その読み、但し無声化サインの`’`は除去
330
+ word, yomi = replace_punctuation(parts["string"]), parts["pron"].replace(
331
+ "’", ""
332
+ )
333
+ """
334
+ ここで`yomi`の取りうる値は以下の通りのはず。
335
+ - `word`が通常単語 → 通常の読み(カタカナ)
336
+ (カタカナからなり、長音記号も含みうる、`アー` 等)
337
+ - `word`が`ー` から始まる → `ーラー` や `ーーー` など
338
+ - `word`が句読点や空白等 → `、`
339
+ - `word`が`?` → `?`(全角になる)
340
+ 他にも`word`が読めないキリル文字アラビア文字等が来ると`、`になるが、正規化でこの場合は起きないはず。
341
+ また元のコードでは`yomi`が空白の場合の処理があったが、これは起きないはず。
342
+ 処理すべきは`yomi`が`、`の場合のみのはず。
343
+ """
344
+ assert yomi != "", f"Empty yomi: {word}"
345
+ if yomi == "、":
346
+ # wordは正規化されているので、`.`, `,`, `!`, `'`, `-`, `--` のいずれか
347
+ if word not in (
348
+ ".",
349
+ ",",
350
+ "!",
351
+ "'",
352
+ "-",
353
+ "--",
354
+ ):
355
+ # ここはpyopenjtalkが読めない文字等のときに起こる
356
+ raise ValueError(f"Cannot read: {word} in:\n{norm_text}")
357
+ # yomiは元の記号のままに変更
358
+ yomi = word
359
+ elif yomi == "?":
360
+ assert word == "?", f"yomi `?` comes from: {word}"
361
+ yomi = "?"
362
+ sep_text.append(word)
363
+ sep_kata.append(yomi)
364
+ return sep_text, sep_kata
365
+
366
+
367
+ # ESPnetの実装から引用、変更点無し。「ん」は「N」なことに注意。
368
+ # https://github.com/espnet/espnet/blob/master/espnet2/text/phoneme_tokenizer.py
369
+ def pyopenjtalk_g2p_prosody(text: str, drop_unvoiced_vowels: bool = True) -> list[str]:
370
+ """Extract phoneme + prosoody symbol sequence from input full-context labels.
371
+
372
+ The algorithm is based on `Prosodic features control by symbols as input of
373
+ sequence-to-sequence acoustic modeling for neural TTS`_ with some r9y9's tweaks.
374
+
375
+ Args:
376
+ text (str): Input text.
377
+ drop_unvoiced_vowels (bool): whether to drop unvoiced vowels.
378
+
379
+ Returns:
380
+ List[str]: List of phoneme + prosody symbols.
381
+
382
+ Examples:
383
+ >>> from espnet2.text.phoneme_tokenizer import pyopenjtalk_g2p_prosody
384
+ >>> pyopenjtalk_g2p_prosody("こんにちは。")
385
+ ['^', 'k', 'o', '[', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', '$']
386
+
387
+ .. _`Prosodic features control by symbols as input of sequence-to-sequence acoustic
388
+ modeling for neural TTS`: https://doi.org/10.1587/transinf.2020EDP7104
389
+
390
+ """
391
+ labels = pyopenjtalk.make_label(pyopenjtalk.run_frontend(text))
392
+ N = len(labels)
393
+
394
+ phones = []
395
+ for n in range(N):
396
+ lab_curr = labels[n]
397
+
398
+ # current phoneme
399
+ p3 = re.search(r"\-(.*?)\+", lab_curr).group(1)
400
+ # deal unvoiced vowels as normal vowels
401
+ if drop_unvoiced_vowels and p3 in "AEIOU":
402
+ p3 = p3.lower()
403
+
404
+ # deal with sil at the beginning and the end of text
405
+ if p3 == "sil":
406
+ assert n == 0 or n == N - 1
407
+ if n == 0:
408
+ phones.append("^")
409
+ elif n == N - 1:
410
+ # check question form or not
411
+ e3 = _numeric_feature_by_regex(r"!(\d+)_", lab_curr)
412
+ if e3 == 0:
413
+ phones.append("$")
414
+ elif e3 == 1:
415
+ phones.append("?")
416
+ continue
417
+ elif p3 == "pau":
418
+ phones.append("_")
419
+ continue
420
+ else:
421
+ phones.append(p3)
422
+
423
+ # accent type and position info (forward or backward)
424
+ a1 = _numeric_feature_by_regex(r"/A:([0-9\-]+)\+", lab_curr)
425
+ a2 = _numeric_feature_by_regex(r"\+(\d+)\+", lab_curr)
426
+ a3 = _numeric_feature_by_regex(r"\+(\d+)/", lab_curr)
427
+
428
+ # number of mora in accent phrase
429
+ f1 = _numeric_feature_by_regex(r"/F:(\d+)_", lab_curr)
430
+
431
+ a2_next = _numeric_feature_by_regex(r"\+(\d+)\+", labels[n + 1])
432
+ # accent phrase border
433
+ if a3 == 1 and a2_next == 1 and p3 in "aeiouAEIOUNcl":
434
+ phones.append("#")
435
+ # pitch falling
436
+ elif a1 == 0 and a2_next == a2 + 1 and a2 != f1:
437
+ phones.append("]")
438
+ # pitch rising
439
+ elif a2 == 1 and a2_next == 2:
440
+ phones.append("[")
441
+
442
+ return phones
443
+
444
+
445
+ def _numeric_feature_by_regex(regex, s):
446
+ match = re.search(regex, s)
447
+ if match is None:
448
+ return -50
449
+ return int(match.group(1))
450
+
451
+
452
+ def fix_phone_tone(phone_tone_list: list[tuple[str, int]]) -> list[tuple[str, int]]:
453
+ """
454
+ `phone_tone_list`のtone(アクセントの値)を0か1の範囲に修正する。
455
+ 例: [(a, 0), (i, -1), (u, -1)] → [(a, 1), (i, 0), (u, 0)]
456
+ """
457
+ tone_values = set(tone for _, tone in phone_tone_list)
458
+ if len(tone_values) == 1:
459
+ assert tone_values == {0}, tone_values
460
+ return phone_tone_list
461
+ elif len(tone_values) == 2:
462
+ if tone_values == {0, 1}:
463
+ return phone_tone_list
464
+ elif tone_values == {-1, 0}:
465
+ return [
466
+ (letter, 0 if tone == -1 else 1) for letter, tone in phone_tone_list
467
+ ]
468
+ else:
469
+ raise ValueError(f"Unexpected tone values: {tone_values}")
470
+ else:
471
+ raise ValueError(f"Unexpected tone values: {tone_values}")
472
+
473
+
474
+ def distribute_phone(n_phone: int, n_word: int) -> list[int]:
475
+ """
476
+ 左から右に1ずつ振り分け、次にまた左から右に1ずつ増やし、というふうに、
477
+ 音素の数`n_phone`を単語の数`n_word`に分配する。
478
+ """
479
+ phones_per_word = [0] * n_word
480
+ for _ in range(n_phone):
481
+ min_tasks = min(phones_per_word)
482
+ min_index = phones_per_word.index(min_tasks)
483
+ phones_per_word[min_index] += 1
484
+ return phones_per_word
485
+
486
+
487
+ def handle_long(sep_phonemes: list[list[str]]) -> list[list[str]]:
488
+ for i in range(len(sep_phonemes)):
489
+ if sep_phonemes[i][0] == "ー":
490
+ sep_phonemes[i][0] = sep_phonemes[i - 1][-1]
491
+ if "ー" in sep_phonemes[i]:
492
+ for j in range(len(sep_phonemes[i])):
493
+ if sep_phonemes[i][j] == "ー":
494
+ sep_phonemes[i][j] = sep_phonemes[i][j - 1][-1]
495
+ return sep_phonemes
496
+
497
+
498
+ tokenizer = AutoTokenizer.from_pretrained("./bert/deberta-v2-large-japanese-char-wwm")
499
+
500
+
501
+ def align_tones(
502
+ phones_with_punct: list[str], phone_tone_list: list[tuple[str, int]]
503
+ ) -> list[tuple[str, int]]:
504
+ """
505
+ 例:
506
+ …私は、、そう思う。
507
+ phones_with_punct:
508
+ [".", ".", ".", "w", "a", "t", "a", "sh", "i", "w", "a", ",", ",", "s", "o", "o", "o", "m", "o", "u", "."]
509
+ phone_tone_list:
510
+ [("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), ("_", 0), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0))]
511
+ Return:
512
+ [(".", 0), (".", 0), (".", 0), ("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), (",", 0), (",", 0), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0), (".", 0)]
513
+ """
514
+ result: list[tuple[str, int]] = []
515
+ tone_index = 0
516
+ for phone in phones_with_punct:
517
+ if tone_index >= len(phone_tone_list):
518
+ # 余ったpunctuationがある場合 → (punctuation, 0)を追加
519
+ result.append((phone, 0))
520
+ elif phone == phone_tone_list[tone_index][0]:
521
+ # phone_tone_listの現在の音素と一致する場合 → toneをそこから取得、(phone, tone)を追加
522
+ result.append((phone, phone_tone_list[tone_index][1]))
523
+ # 探すindexを1つ進める
524
+ tone_index += 1
525
+ elif phone in punctuation:
526
+ # phoneがpunctuationの場合 → (phone, 0)を追加
527
+ result.append((phone, 0))
528
+ else:
529
+ logger.debug(f"phones: {phones_with_punct}")
530
+ logger.debug(f"phone_tone_list: {phone_tone_list}")
531
+ logger.debug(f"result: {result}")
532
+ logger.debug(f"tone_index: {tone_index}")
533
+ logger.debug(f"phone: {phone}")
534
+ raise ValueError(f"Unexpected phone: {phone}")
535
+ return result
536
+
537
+
538
+ def kata2phoneme_list(text: str) -> list[str]:
539
+ """
540
+ 原則カタカナの`text`を受け取り、それをそのままいじらずに音素記号のリストに変換。
541
+ 注意点:
542
+ - punctuationが来た場合(punctuationが1文字の場合がありうる)、処理せず1文字のリストを返す
543
+ - 冒頭に続く「ー」はそのまま「ー」のままにする(`handle_long()`で処理される)
544
+ - 文中の「ー」は前の音素記号の最後の音素記号に変換される。
545
+ 例:
546
+ `ーーソーナノカーー` → ["ー", "ー", "s", "o", "o", "n", "a", "n", "o", "k", "a", "a", "a"]
547
+ `?` → ["?"]
548
+ """
549
+ if text in punctuation:
550
+ return [text]
551
+ elif text == "--":
552
+ return ["-", "-"]
553
+ # `text`がカタカナ(`ー`含む)のみからなるかどうかをチェック
554
+ if re.fullmatch(r"[\u30A0-\u30FF]+", text) is None:
555
+ raise ValueError(f"Input must be katakana only: {text}")
556
+ sorted_keys = sorted(mora_kata_to_mora_phonemes.keys(), key=len, reverse=True)
557
+ pattern = "|".join(map(re.escape, sorted_keys))
558
+
559
+ def mora2phonemes(mora: str) -> str:
560
+ cosonant, vowel = mora_kata_to_mora_phonemes[mora]
561
+ if cosonant is None:
562
+ return f" {vowel}"
563
+ return f" {cosonant} {vowel}"
564
+
565
+ spaced_phonemes = re.sub(pattern, lambda m: mora2phonemes(m.group()), text)
566
+
567
+ # 長音記号「ー」の処理
568
+ long_pattern = r"(\w)(ー*)"
569
+ long_replacement = lambda m: m.group(1) + (" " + m.group(1)) * len(m.group(2))
570
+ spaced_phonemes = re.sub(long_pattern, long_replacement, spaced_phonemes)
571
+ return spaced_phonemes.strip().split(" ")
572
+
573
+
574
+ if __name__ == "__main__":
575
+ tokenizer = AutoTokenizer.from_pretrained("./bert/deberta-v2-large-japanese")
576
+ text = "hello,こんにちは、世界ー!……"
577
+ from text.japanese_bert import get_bert_feature
578
+
579
+ text = text_normalize(text)
580
+ print(text)
581
+
582
+ phones, tones, word2ph = g2p(text)
583
+ bert = get_bert_feature(text, word2ph)
584
+
585
+ print(phones, tones, word2ph, bert.shape)
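A minimal sketch of reading accents back out as (katakana mora, tone) pairs with g2kata_tone(); it assumes pyopenjtalk and the local ./bert tokenizer are available, and the exact tones in the comment depend on pyopenjtalk's analysis.

from text.japanese import text_normalize, g2kata_tone

norm_text = text_normalize("こんにちは、世界ー。")
print(g2kata_tone(norm_text))
# e.g. [('コ', 0), ('ン', 1), ('ニ', 1), ('チ', 1), ('ワ', 1), (',', 0), ...]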
text/japanese_bert.py ADDED
@@ -0,0 +1,67 @@
1
+ import sys
2
+
3
+ import torch
4
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
5
+
6
+ from config import config
7
+ from text.japanese import text2sep_kata
8
+
9
+ LOCAL_PATH = "./bert/deberta-v2-large-japanese-char-wwm"
10
+
11
+ tokenizer = AutoTokenizer.from_pretrained(LOCAL_PATH)
12
+
13
+ models = dict()
14
+
15
+
16
+ def get_bert_feature(
17
+ text,
18
+ word2ph,
19
+ device=config.bert_gen_config.device,
20
+ assist_text=None,
21
+ assist_text_weight=0.7,
22
+ ):
23
+ text = "".join(text2sep_kata(text)[0])
24
+ if assist_text:
25
+ assist_text = "".join(text2sep_kata(assist_text)[0])
26
+ if (
27
+ sys.platform == "darwin"
28
+ and torch.backends.mps.is_available()
29
+ and device == "cpu"
30
+ ):
31
+ device = "mps"
32
+ if not device:
33
+ device = "cuda"
34
+ if device == "cuda" and not torch.cuda.is_available():
35
+ device = "cpu"
36
+ if device not in models.keys():
37
+ models[device] = AutoModelForMaskedLM.from_pretrained(LOCAL_PATH).to(device)
38
+ with torch.no_grad():
39
+ inputs = tokenizer(text, return_tensors="pt")
40
+ for i in inputs:
41
+ inputs[i] = inputs[i].to(device)
42
+ res = models[device](**inputs, output_hidden_states=True)
43
+ res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
44
+ if assist_text:
45
+ style_inputs = tokenizer(assist_text, return_tensors="pt")
46
+ for i in style_inputs:
47
+ style_inputs[i] = style_inputs[i].to(device)
48
+ style_res = models[device](**style_inputs, output_hidden_states=True)
49
+ style_res = torch.cat(style_res["hidden_states"][-3:-2], -1)[0].cpu()
50
+ style_res_mean = style_res.mean(0)
51
+
52
+ assert len(word2ph) == len(text) + 2, text
53
+ word2phone = word2ph
54
+ phone_level_feature = []
55
+ for i in range(len(word2phone)):
56
+ if assist_text:
57
+ repeat_feature = (
58
+ res[i].repeat(word2phone[i], 1) * (1 - assist_text_weight)
59
+ + style_res_mean.repeat(word2phone[i], 1) * assist_text_weight
60
+ )
61
+ else:
62
+ repeat_feature = res[i].repeat(word2phone[i], 1)
63
+ phone_level_feature.append(repeat_feature)
64
+
65
+ phone_level_feature = torch.cat(phone_level_feature, dim=0)
66
+
67
+ return phone_level_feature.T
text/japanese_mora_list.py ADDED
@@ -0,0 +1,232 @@
1
+ """
2
+ VOICEVOXのソースコードからお借りして最低限に改造したコード。
3
+ https://github.com/VOICEVOX/voicevox_engine/blob/master/voicevox_engine/tts_pipeline/mora_list.py
4
+ """
5
+
6
+ """
7
+ 以下のモーラ対応表はOpenJTalkのソースコードから取得し、
8
+ カタカナ表記とモーラが一対一対応するように改造した。
9
+ ライセンス表記:
10
+ -----------------------------------------------------------------
11
+ The Japanese TTS System "Open JTalk"
12
+ developed by HTS Working Group
13
+ http://open-jtalk.sourceforge.net/
14
+ -----------------------------------------------------------------
15
+
16
+ Copyright (c) 2008-2014 Nagoya Institute of Technology
17
+ Department of Computer Science
18
+
19
+ All rights reserved.
20
+
21
+ Redistribution and use in source and binary forms, with or
22
+ without modification, are permitted provided that the following
23
+ conditions are met:
24
+
25
+ - Redistributions of source code must retain the above copyright
26
+ notice, this list of conditions and the following disclaimer.
27
+ - Redistributions in binary form must reproduce the above
28
+ copyright notice, this list of conditions and the following
29
+ disclaimer in the documentation and/or other materials provided
30
+ with the distribution.
31
+ - Neither the name of the HTS working group nor the names of its
32
+ contributors may be used to endorse or promote products derived
33
+ from this software without specific prior written permission.
34
+
35
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
36
+ CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
37
+ INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
38
+ MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
39
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS
40
+ BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
41
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
42
+ TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
43
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
44
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
45
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
46
+ OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
47
+ POSSIBILITY OF SUCH DAMAGE.
48
+ """
49
+ from typing import Optional
50
+
51
+ # Order is (katakana, consonant, vowel); None is used when there is no consonant.
52
+ # However, 「ン」 and 「ッ」 are treated as vowel-only: 「ン」 is "N" and 「ッ」 is "q".
53
+ # (Originally 「ッ」 was "cl".)
54
+ # Also, 「デェ = dy e」 was removed because it does not match pyopenjtalk's output (de e).
55
+ _mora_list_minimum: list[tuple[str, Optional[str], str]] = [
56
+ ("ヴォ", "v", "o"),
57
+ ("ヴェ", "v", "e"),
58
+ ("ヴィ", "v", "i"),
59
+ ("ヴァ", "v", "a"),
60
+ ("ヴ", "v", "u"),
61
+ ("ン", None, "N"),
62
+ ("ワ", "w", "a"),
63
+ ("ロ", "r", "o"),
64
+ ("レ", "r", "e"),
65
+ ("ル", "r", "u"),
66
+ ("リョ", "ry", "o"),
67
+ ("リュ", "ry", "u"),
68
+ ("リャ", "ry", "a"),
69
+ ("リェ", "ry", "e"),
70
+ ("リ", "r", "i"),
71
+ ("ラ", "r", "a"),
72
+ ("ヨ", "y", "o"),
73
+ ("ユ", "y", "u"),
74
+ ("ヤ", "y", "a"),
75
+ ("モ", "m", "o"),
76
+ ("メ", "m", "e"),
77
+ ("ム", "m", "u"),
78
+ ("ミョ", "my", "o"),
79
+ ("ミュ", "my", "u"),
80
+ ("ミャ", "my", "a"),
81
+ ("ミェ", "my", "e"),
82
+ ("ミ", "m", "i"),
83
+ ("マ", "m", "a"),
84
+ ("ポ", "p", "o"),
85
+ ("ボ", "b", "o"),
86
+ ("ホ", "h", "o"),
87
+ ("ペ", "p", "e"),
88
+ ("ベ", "b", "e"),
89
+ ("ヘ", "h", "e"),
90
+ ("プ", "p", "u"),
91
+ ("ブ", "b", "u"),
92
+ ("フォ", "f", "o"),
93
+ ("フェ", "f", "e"),
94
+ ("フィ", "f", "i"),
95
+ ("ファ", "f", "a"),
96
+ ("フ", "f", "u"),
97
+ ("ピョ", "py", "o"),
98
+ ("ピュ", "py", "u"),
99
+ ("ピャ", "py", "a"),
100
+ ("ピェ", "py", "e"),
101
+ ("ピ", "p", "i"),
102
+ ("ビョ", "by", "o"),
103
+ ("ビュ", "by", "u"),
104
+ ("ビャ", "by", "a"),
105
+ ("ビェ", "by", "e"),
106
+ ("ビ", "b", "i"),
107
+ ("ヒョ", "hy", "o"),
108
+ ("ヒュ", "hy", "u"),
109
+ ("ヒャ", "hy", "a"),
110
+ ("ヒェ", "hy", "e"),
111
+ ("ヒ", "h", "i"),
112
+ ("パ", "p", "a"),
113
+ ("バ", "b", "a"),
114
+ ("ハ", "h", "a"),
115
+ ("ノ", "n", "o"),
116
+ ("ネ", "n", "e"),
117
+ ("ヌ", "n", "u"),
118
+ ("ニョ", "ny", "o"),
119
+ ("ニュ", "ny", "u"),
120
+ ("ニャ", "ny", "a"),
121
+ ("ニェ", "ny", "e"),
122
+ ("ニ", "n", "i"),
123
+ ("ナ", "n", "a"),
124
+ ("ドゥ", "d", "u"),
125
+ ("ド", "d", "o"),
126
+ ("トゥ", "t", "u"),
127
+ ("ト", "t", "o"),
128
+ ("デョ", "dy", "o"),
129
+ ("デュ", "dy", "u"),
130
+ ("デャ", "dy", "a"),
131
+ # ("デェ", "dy", "e"),
132
+ ("ディ", "d", "i"),
133
+ ("デ", "d", "e"),
134
+ ("テョ", "ty", "o"),
135
+ ("テュ", "ty", "u"),
136
+ ("テャ", "ty", "a"),
137
+ ("ティ", "t", "i"),
138
+ ("テ", "t", "e"),
139
+ ("ツォ", "ts", "o"),
140
+ ("ツェ", "ts", "e"),
141
+ ("ツィ", "ts", "i"),
142
+ ("ツァ", "ts", "a"),
143
+ ("ツ", "ts", "u"),
144
+ ("ッ", None, "q"), # changed from "cl" to "q"
145
+ ("チョ", "ch", "o"),
146
+ ("チュ", "ch", "u"),
147
+ ("チャ", "ch", "a"),
148
+ ("チェ", "ch", "e"),
149
+ ("チ", "ch", "i"),
150
+ ("ダ", "d", "a"),
151
+ ("タ", "t", "a"),
152
+ ("ゾ", "z", "o"),
153
+ ("ソ", "s", "o"),
154
+ ("ゼ", "z", "e"),
155
+ ("セ", "s", "e"),
156
+ ("ズィ", "z", "i"),
157
+ ("ズ", "z", "u"),
158
+ ("スィ", "s", "i"),
159
+ ("ス", "s", "u"),
160
+ ("ジョ", "j", "o"),
161
+ ("ジュ", "j", "u"),
162
+ ("ジャ", "j", "a"),
163
+ ("ジェ", "j", "e"),
164
+ ("ジ", "j", "i"),
165
+ ("ショ", "sh", "o"),
166
+ ("シュ", "sh", "u"),
167
+ ("シャ", "sh", "a"),
168
+ ("シェ", "sh", "e"),
169
+ ("シ", "sh", "i"),
170
+ ("ザ", "z", "a"),
171
+ ("サ", "s", "a"),
172
+ ("ゴ", "g", "o"),
173
+ ("コ", "k", "o"),
174
+ ("ゲ", "g", "e"),
175
+ ("ケ", "k", "e"),
176
+ ("グヮ", "gw", "a"),
177
+ ("グ", "g", "u"),
178
+ ("クヮ", "kw", "a"),
179
+ ("ク", "k", "u"),
180
+ ("ギョ", "gy", "o"),
181
+ ("ギュ", "gy", "u"),
182
+ ("ギャ", "gy", "a"),
183
+ ("ギェ", "gy", "e"),
184
+ ("ギ", "g", "i"),
185
+ ("キョ", "ky", "o"),
186
+ ("キュ", "ky", "u"),
187
+ ("キャ", "ky", "a"),
188
+ ("キェ", "ky", "e"),
189
+ ("キ", "k", "i"),
190
+ ("ガ", "g", "a"),
191
+ ("カ", "k", "a"),
192
+ ("オ", None, "o"),
193
+ ("エ", None, "e"),
194
+ ("ウォ", "w", "o"),
195
+ ("ウェ", "w", "e"),
196
+ ("ウィ", "w", "i"),
197
+ ("ウ", None, "u"),
198
+ ("イェ", "y", "e"),
199
+ ("イ", None, "i"),
200
+ ("ア", None, "a"),
201
+ ]
202
+ _mora_list_additional: list[tuple[str, Optional[str], str]] = [
203
+ ("ヴョ", "by", "o"),
204
+ ("ヴュ", "by", "u"),
205
+ ("ヴャ", "by", "a"),
206
+ ("ヲ", None, "o"),
207
+ ("ヱ", None, "e"),
208
+ ("ヰ", None, "i"),
209
+ ("ヮ", "w", "a"),
210
+ ("ョ", "y", "o"),
211
+ ("ュ", "y", "u"),
212
+ ("ヅ", "z", "u"),
213
+ ("ヂ", "j", "i"),
214
+ ("ヶ", "k", "e"),
215
+ ("ャ", "y", "a"),
216
+ ("ォ", None, "o"),
217
+ ("ェ", None, "e"),
218
+ ("ゥ", None, "u"),
219
+ ("ィ", None, "i"),
220
+ ("ァ", None, "a"),
221
+ ]
222
+
223
+ # e.g. "vo" -> "ヴォ", "a" -> "ア"
224
+ mora_phonemes_to_mora_kata: dict[str, str] = {
225
+ (consonant or "") + vowel: kana for [kana, consonant, vowel] in _mora_list_minimum
226
+ }
227
+
228
+ # e.g. "ヴォ" -> ("v", "o"), "ア" -> (None, "a")
229
+ mora_kata_to_mora_phonemes: dict[str, tuple[Optional[str], str]] = {
230
+ kana: (consonant, vowel)
231
+ for [kana, consonant, vowel] in _mora_list_minimum + _mora_list_additional
232
+ }
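A small usage sketch for the two lookup tables defined above; the import path assumes the module is loaded as text.japanese_mora_list, matching the file location in this commit.

from text.japanese_mora_list import (
    mora_kata_to_mora_phonemes,
    mora_phonemes_to_mora_kata,
)

print(mora_kata_to_mora_phonemes["キョ"])  # ("ky", "o")
print(mora_kata_to_mora_phonemes["ン"])    # (None, "N")
print(mora_phonemes_to_mora_kata["kyo"])   # "キョ"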
text/opencpop-strict.txt ADDED
@@ -0,0 +1,429 @@
1
+ a AA a
2
+ ai AA ai
3
+ an AA an
4
+ ang AA ang
5
+ ao AA ao
6
+ ba b a
7
+ bai b ai
8
+ ban b an
9
+ bang b ang
10
+ bao b ao
11
+ bei b ei
12
+ ben b en
13
+ beng b eng
14
+ bi b i
15
+ bian b ian
16
+ biao b iao
17
+ bie b ie
18
+ bin b in
19
+ bing b ing
20
+ bo b o
21
+ bu b u
22
+ ca c a
23
+ cai c ai
24
+ can c an
25
+ cang c ang
26
+ cao c ao
27
+ ce c e
28
+ cei c ei
29
+ cen c en
30
+ ceng c eng
31
+ cha ch a
32
+ chai ch ai
33
+ chan ch an
34
+ chang ch ang
35
+ chao ch ao
36
+ che ch e
37
+ chen ch en
38
+ cheng ch eng
39
+ chi ch ir
40
+ chong ch ong
41
+ chou ch ou
42
+ chu ch u
43
+ chua ch ua
44
+ chuai ch uai
45
+ chuan ch uan
46
+ chuang ch uang
47
+ chui ch ui
48
+ chun ch un
49
+ chuo ch uo
50
+ ci c i0
51
+ cong c ong
52
+ cou c ou
53
+ cu c u
54
+ cuan c uan
55
+ cui c ui
56
+ cun c un
57
+ cuo c uo
58
+ da d a
59
+ dai d ai
60
+ dan d an
61
+ dang d ang
62
+ dao d ao
63
+ de d e
64
+ dei d ei
65
+ den d en
66
+ deng d eng
67
+ di d i
68
+ dia d ia
69
+ dian d ian
70
+ diao d iao
71
+ die d ie
72
+ ding d ing
73
+ diu d iu
74
+ dong d ong
75
+ dou d ou
76
+ du d u
77
+ duan d uan
78
+ dui d ui
79
+ dun d un
80
+ duo d uo
81
+ e EE e
82
+ ei EE ei
83
+ en EE en
84
+ eng EE eng
85
+ er EE er
86
+ fa f a
87
+ fan f an
88
+ fang f ang
89
+ fei f ei
90
+ fen f en
91
+ feng f eng
92
+ fo f o
93
+ fou f ou
94
+ fu f u
95
+ ga g a
96
+ gai g ai
97
+ gan g an
98
+ gang g ang
99
+ gao g ao
100
+ ge g e
101
+ gei g ei
102
+ gen g en
103
+ geng g eng
104
+ gong g ong
105
+ gou g ou
106
+ gu g u
107
+ gua g ua
108
+ guai g uai
109
+ guan g uan
110
+ guang g uang
111
+ gui g ui
112
+ gun g un
113
+ guo g uo
114
+ ha h a
115
+ hai h ai
116
+ han h an
117
+ hang h ang
118
+ hao h ao
119
+ he h e
120
+ hei h ei
121
+ hen h en
122
+ heng h eng
123
+ hong h ong
124
+ hou h ou
125
+ hu h u
126
+ hua h ua
127
+ huai h uai
128
+ huan h uan
129
+ huang h uang
130
+ hui h ui
131
+ hun h un
132
+ huo h uo
133
+ ji j i
134
+ jia j ia
135
+ jian j ian
136
+ jiang j iang
137
+ jiao j iao
138
+ jie j ie
139
+ jin j in
140
+ jing j ing
141
+ jiong j iong
142
+ jiu j iu
143
+ ju j v
144
+ jv j v
145
+ juan j van
146
+ jvan j van
147
+ jue j ve
148
+ jve j ve
149
+ jun j vn
150
+ jvn j vn
151
+ ka k a
152
+ kai k ai
153
+ kan k an
154
+ kang k ang
155
+ kao k ao
156
+ ke k e
157
+ kei k ei
158
+ ken k en
159
+ keng k eng
160
+ kong k ong
161
+ kou k ou
162
+ ku k u
163
+ kua k ua
164
+ kuai k uai
165
+ kuan k uan
166
+ kuang k uang
167
+ kui k ui
168
+ kun k un
169
+ kuo k uo
170
+ la l a
171
+ lai l ai
172
+ lan l an
173
+ lang l ang
174
+ lao l ao
175
+ le l e
176
+ lei l ei
177
+ leng l eng
178
+ li l i
179
+ lia l ia
180
+ lian l ian
181
+ liang l iang
182
+ liao l iao
183
+ lie l ie
184
+ lin l in
185
+ ling l ing
186
+ liu l iu
187
+ lo l o
188
+ long l ong
189
+ lou l ou
190
+ lu l u
191
+ luan l uan
192
+ lun l un
193
+ luo l uo
194
+ lv l v
195
+ lve l ve
196
+ ma m a
197
+ mai m ai
198
+ man m an
199
+ mang m ang
200
+ mao m ao
201
+ me m e
202
+ mei m ei
203
+ men m en
204
+ meng m eng
205
+ mi m i
206
+ mian m ian
207
+ miao m iao
208
+ mie m ie
209
+ min m in
210
+ ming m ing
211
+ miu m iu
212
+ mo m o
213
+ mou m ou
214
+ mu m u
215
+ na n a
216
+ nai n ai
217
+ nan n an
218
+ nang n ang
219
+ nao n ao
220
+ ne n e
221
+ nei n ei
222
+ nen n en
223
+ neng n eng
224
+ ni n i
225
+ nian n ian
226
+ niang n iang
227
+ niao n iao
228
+ nie n ie
229
+ nin n in
230
+ ning n ing
231
+ niu n iu
232
+ nong n ong
233
+ nou n ou
234
+ nu n u
235
+ nuan n uan
236
+ nun n un
237
+ nuo n uo
238
+ nv n v
239
+ nve n ve
240
+ o OO o
241
+ ou OO ou
242
+ pa p a
243
+ pai p ai
244
+ pan p an
245
+ pang p ang
246
+ pao p ao
247
+ pei p ei
248
+ pen p en
249
+ peng p eng
250
+ pi p i
251
+ pian p ian
252
+ piao p iao
253
+ pie p ie
254
+ pin p in
255
+ ping p ing
256
+ po p o
257
+ pou p ou
258
+ pu p u
259
+ qi q i
260
+ qia q ia
261
+ qian q ian
262
+ qiang q iang
263
+ qiao q iao
264
+ qie q ie
265
+ qin q in
266
+ qing q ing
267
+ qiong q iong
268
+ qiu q iu
269
+ qu q v
270
+ qv q v
271
+ quan q van
272
+ qvan q van
273
+ que q ve
274
+ qve q ve
275
+ qun q vn
276
+ qvn q vn
277
+ ran r an
278
+ rang r ang
279
+ rao r ao
280
+ re r e
281
+ ren r en
282
+ reng r eng
283
+ ri r ir
284
+ rong r ong
285
+ rou r ou
286
+ ru r u
287
+ rua r ua
288
+ ruan r uan
289
+ rui r ui
290
+ run r un
291
+ ruo r uo
292
+ sa s a
293
+ sai s ai
294
+ san s an
295
+ sang s ang
296
+ sao s ao
297
+ se s e
298
+ sen s en
299
+ seng s eng
300
+ sha sh a
301
+ shai sh ai
302
+ shan sh an
303
+ shang sh ang
304
+ shao sh ao
305
+ she sh e
306
+ shei sh ei
307
+ shen sh en
308
+ sheng sh eng
309
+ shi sh ir
310
+ shou sh ou
311
+ shu sh u
312
+ shua sh ua
313
+ shuai sh uai
314
+ shuan sh uan
315
+ shuang sh uang
316
+ shui sh ui
317
+ shun sh un
318
+ shuo sh uo
319
+ si s i0
320
+ song s ong
321
+ sou s ou
322
+ su s u
323
+ suan s uan
324
+ sui s ui
325
+ sun s un
326
+ suo s uo
327
+ ta t a
328
+ tai t ai
329
+ tan t an
330
+ tang t ang
331
+ tao t ao
332
+ te t e
333
+ tei t ei
334
+ teng t eng
335
+ ti t i
336
+ tian t ian
337
+ tiao t iao
338
+ tie t ie
339
+ ting t ing
340
+ tong t ong
341
+ tou t ou
342
+ tu t u
343
+ tuan t uan
344
+ tui t ui
345
+ tun t un
346
+ tuo t uo
347
+ wa w a
348
+ wai w ai
349
+ wan w an
350
+ wang w ang
351
+ wei w ei
352
+ wen w en
353
+ weng w eng
354
+ wo w o
355
+ wu w u
356
+ xi x i
357
+ xia x ia
358
+ xian x ian
359
+ xiang x iang
360
+ xiao x iao
361
+ xie x ie
362
+ xin x in
363
+ xing x ing
364
+ xiong x iong
365
+ xiu x iu
366
+ xu x v
367
+ xv x v
368
+ xuan x van
369
+ xvan x van
370
+ xue x ve
371
+ xve x ve
372
+ xun x vn
373
+ xvn x vn
374
+ ya y a
375
+ yan y En
376
+ yang y ang
377
+ yao y ao
378
+ ye y E
379
+ yi y i
380
+ yin y in
381
+ ying y ing
382
+ yo y o
383
+ yong y ong
384
+ you y ou
385
+ yu y v
386
+ yv y v
387
+ yuan y van
388
+ yvan y van
389
+ yue y ve
390
+ yve y ve
391
+ yun y vn
392
+ yvn y vn
393
+ za z a
394
+ zai z ai
395
+ zan z an
396
+ zang z ang
397
+ zao z ao
398
+ ze z e
399
+ zei z ei
400
+ zen z en
401
+ zeng z eng
402
+ zha zh a
403
+ zhai zh ai
404
+ zhan zh an
405
+ zhang zh ang
406
+ zhao zh ao
407
+ zhe zh e
408
+ zhei zh ei
409
+ zhen zh en
410
+ zheng zh eng
411
+ zhi zh ir
412
+ zhong zh ong
413
+ zhou zh ou
414
+ zhu zh u
415
+ zhua zh ua
416
+ zhuai zh uai
417
+ zhuan zh uan
418
+ zhuang zh uang
419
+ zhui zh ui
420
+ zhun zh un
421
+ zhuo zh uo
422
+ zi z i0
423
+ zong z ong
424
+ zou z ou
425
+ zu z u
426
+ zuan z uan
427
+ zui z ui
428
+ zun z un
429
+ zuo z uo
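A hedged sketch of how this table can be read; each row has three whitespace-separated columns (pinyin syllable, initial, final), and the path is assumed relative to the repository root.

pinyin_to_phones = {}
with open("text/opencpop-strict.txt", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue
        pinyin, initial, final = line.split()
        pinyin_to_phones[pinyin] = (initial, final)

print(pinyin_to_phones["zhong"])  # ('zh', 'ong')
print(pinyin_to_phones["a"])      # ('AA', 'a')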
text/symbols.py ADDED
@@ -0,0 +1,187 @@
1
+ punctuation = ["!", "?", "…", ",", ".", "'", "-"]
2
+ pu_symbols = punctuation + ["SP", "UNK"]
3
+ pad = "_"
4
+
5
+ # chinese
6
+ zh_symbols = [
7
+ "E",
8
+ "En",
9
+ "a",
10
+ "ai",
11
+ "an",
12
+ "ang",
13
+ "ao",
14
+ "b",
15
+ "c",
16
+ "ch",
17
+ "d",
18
+ "e",
19
+ "ei",
20
+ "en",
21
+ "eng",
22
+ "er",
23
+ "f",
24
+ "g",
25
+ "h",
26
+ "i",
27
+ "i0",
28
+ "ia",
29
+ "ian",
30
+ "iang",
31
+ "iao",
32
+ "ie",
33
+ "in",
34
+ "ing",
35
+ "iong",
36
+ "ir",
37
+ "iu",
38
+ "j",
39
+ "k",
40
+ "l",
41
+ "m",
42
+ "n",
43
+ "o",
44
+ "ong",
45
+ "ou",
46
+ "p",
47
+ "q",
48
+ "r",
49
+ "s",
50
+ "sh",
51
+ "t",
52
+ "u",
53
+ "ua",
54
+ "uai",
55
+ "uan",
56
+ "uang",
57
+ "ui",
58
+ "un",
59
+ "uo",
60
+ "v",
61
+ "van",
62
+ "ve",
63
+ "vn",
64
+ "w",
65
+ "x",
66
+ "y",
67
+ "z",
68
+ "zh",
69
+ "AA",
70
+ "EE",
71
+ "OO",
72
+ ]
73
+ num_zh_tones = 6
74
+
75
+ # japanese
76
+ ja_symbols = [
77
+ "N",
78
+ "a",
79
+ "a:",
80
+ "b",
81
+ "by",
82
+ "ch",
83
+ "d",
84
+ "dy",
85
+ "e",
86
+ "e:",
87
+ "f",
88
+ "g",
89
+ "gy",
90
+ "h",
91
+ "hy",
92
+ "i",
93
+ "i:",
94
+ "j",
95
+ "k",
96
+ "ky",
97
+ "m",
98
+ "my",
99
+ "n",
100
+ "ny",
101
+ "o",
102
+ "o:",
103
+ "p",
104
+ "py",
105
+ "q",
106
+ "r",
107
+ "ry",
108
+ "s",
109
+ "sh",
110
+ "t",
111
+ "ts",
112
+ "ty",
113
+ "u",
114
+ "u:",
115
+ "w",
116
+ "y",
117
+ "z",
118
+ "zy",
119
+ ]
120
+ num_ja_tones = 2
121
+
122
+ # English
123
+ en_symbols = [
124
+ "aa",
125
+ "ae",
126
+ "ah",
127
+ "ao",
128
+ "aw",
129
+ "ay",
130
+ "b",
131
+ "ch",
132
+ "d",
133
+ "dh",
134
+ "eh",
135
+ "er",
136
+ "ey",
137
+ "f",
138
+ "g",
139
+ "hh",
140
+ "ih",
141
+ "iy",
142
+ "jh",
143
+ "k",
144
+ "l",
145
+ "m",
146
+ "n",
147
+ "ng",
148
+ "ow",
149
+ "oy",
150
+ "p",
151
+ "r",
152
+ "s",
153
+ "sh",
154
+ "t",
155
+ "th",
156
+ "uh",
157
+ "uw",
158
+ "V",
159
+ "w",
160
+ "y",
161
+ "z",
162
+ "zh",
163
+ ]
164
+ num_en_tones = 4
165
+
166
+ # combine all symbols
167
+ normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols))
168
+ symbols = [pad] + normal_symbols + pu_symbols
169
+ sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]
170
+
171
+ # combine all tones
172
+ num_tones = num_zh_tones + num_ja_tones + num_en_tones
173
+
174
+ # language maps
175
+ language_id_map = {"ZH": 0, "JP": 1, "EN": 2}
176
+ num_languages = len(language_id_map.keys())
177
+
178
+ language_tone_start_map = {
179
+ "ZH": 0,
180
+ "JP": num_zh_tones,
181
+ "EN": num_zh_tones + num_ja_tones,
182
+ }
183
+
184
+ if __name__ == "__main__":
185
+ a = set(zh_symbols)
186
+ b = set(en_symbols)
187
+ print(sorted(a & b))
text/tone_sandhi.py ADDED
@@ -0,0 +1,773 @@
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List
15
+ from typing import Tuple
16
+
17
+ import jieba
18
+ from pypinyin import lazy_pinyin
19
+ from pypinyin import Style
20
+
21
+
22
+ class ToneSandhi:
23
+ def __init__(self):
24
+ self.must_neural_tone_words = {
25
+ "麻烦",
26
+ "麻利",
27
+ "鸳鸯",
28
+ "高粱",
29
+ "骨头",
30
+ "骆驼",
31
+ "马虎",
32
+ "首饰",
33
+ "馒头",
34
+ "馄饨",
35
+ "风筝",
36
+ "难为",
37
+ "队伍",
38
+ "阔气",
39
+ "闺女",
40
+ "门道",
41
+ "锄头",
42
+ "铺盖",
43
+ "铃铛",
44
+ "铁匠",
45
+ "钥匙",
46
+ "里脊",
47
+ "里头",
48
+ "部分",
49
+ "那么",
50
+ "道士",
51
+ "造化",
52
+ "迷糊",
53
+ "连累",
54
+ "这么",
55
+ "这个",
56
+ "运气",
57
+ "过去",
58
+ "软和",
59
+ "转悠",
60
+ "踏实",
61
+ "跳蚤",
62
+ "跟头",
63
+ "趔趄",
64
+ "财主",
65
+ "豆腐",
66
+ "讲究",
67
+ "记性",
68
+ "记号",
69
+ "认识",
70
+ "规矩",
71
+ "见识",
72
+ "裁缝",
73
+ "补丁",
74
+ "衣裳",
75
+ "衣服",
76
+ "衙门",
77
+ "街坊",
78
+ "行李",
79
+ "行当",
80
+ "蛤蟆",
81
+ "蘑菇",
82
+ "薄荷",
83
+ "葫芦",
84
+ "葡萄",
85
+ "萝卜",
86
+ "荸荠",
87
+ "苗条",
88
+ "苗头",
89
+ "苍蝇",
90
+ "芝麻",
91
+ "舒服",
92
+ "舒坦",
93
+ "舌头",
94
+ "自在",
95
+ "膏药",
96
+ "脾气",
97
+ "脑袋",
98
+ "脊梁",
99
+ "能耐",
100
+ "胳膊",
101
+ "胭脂",
102
+ "胡萝",
103
+ "胡琴",
104
+ "胡同",
105
+ "聪明",
106
+ "耽误",
107
+ "耽搁",
108
+ "耷拉",
109
+ "耳朵",
110
+ "老爷",
111
+ "老实",
112
+ "老婆",
113
+ "老头",
114
+ "老太",
115
+ "翻腾",
116
+ "罗嗦",
117
+ "罐头",
118
+ "编辑",
119
+ "结实",
120
+ "红火",
121
+ "累赘",
122
+ "糨糊",
123
+ "糊涂",
124
+ "精神",
125
+ "粮食",
126
+ "簸箕",
127
+ "篱笆",
128
+ "算计",
129
+ "算盘",
130
+ "答应",
131
+ "笤帚",
132
+ "笑语",
133
+ "笑话",
134
+ "窟窿",
135
+ "窝囊",
136
+ "窗户",
137
+ "稳当",
138
+ "稀罕",
139
+ "称呼",
140
+ "秧歌",
141
+ "秀气",
142
+ "秀才",
143
+ "福气",
144
+ "祖宗",
145
+ "砚台",
146
+ "码头",
147
+ "石榴",
148
+ "石头",
149
+ "石匠",
150
+ "知识",
151
+ "眼睛",
152
+ "眯缝",
153
+ "眨巴",
154
+ "眉毛",
155
+ "相声",
156
+ "盘算",
157
+ "白净",
158
+ "痢疾",
159
+ "痛快",
160
+ "疟疾",
161
+ "疙瘩",
162
+ "疏忽",
163
+ "畜生",
164
+ "生意",
165
+ "甘蔗",
166
+ "琵琶",
167
+ "琢磨",
168
+ "琉璃",
169
+ "玻璃",
170
+ "玫瑰",
171
+ "玄乎",
172
+ "狐狸",
173
+ "状元",
174
+ "特务",
175
+ "牲口",
176
+ "牙碜",
177
+ "牌楼",
178
+ "爽快",
179
+ "爱人",
180
+ "热闹",
181
+ "烧饼",
182
+ "烟筒",
183
+ "烂糊",
184
+ "点心",
185
+ "炊帚",
186
+ "灯笼",
187
+ "火候",
188
+ "漂亮",
189
+ "滑溜",
190
+ "溜达",
191
+ "温和",
192
+ "清楚",
193
+ "消息",
194
+ "浪头",
195
+ "活泼",
196
+ "比方",
197
+ "正经",
198
+ "欺负",
199
+ "模糊",
200
+ "槟榔",
201
+ "棺材",
202
+ "棒槌",
203
+ "棉花",
204
+ "核桃",
205
+ "栅栏",
206
+ "柴火",
207
+ "架势",
208
+ "枕头",
209
+ "枇杷",
210
+ "机灵",
211
+ "本事",
212
+ "木头",
213
+ "木匠",
214
+ "朋友",
215
+ "月饼",
216
+ "月亮",
217
+ "暖和",
218
+ "明白",
219
+ "时候",
220
+ "新鲜",
221
+ "故事",
222
+ "收拾",
223
+ "收成",
224
+ "提防",
225
+ "挖苦",
226
+ "挑剔",
227
+ "指甲",
228
+ "指头",
229
+ "拾掇",
230
+ "拳头",
231
+ "拨弄",
232
+ "招牌",
233
+ "招呼",
234
+ "抬举",
235
+ "护士",
236
+ "折腾",
237
+ "扫帚",
238
+ "打量",
239
+ "打算",
240
+ "打点",
241
+ "打扮",
242
+ "打听",
243
+ "打发",
244
+ "扎实",
245
+ "扁担",
246
+ "戒指",
247
+ "懒得",
248
+ "意识",
249
+ "意思",
250
+ "情形",
251
+ "悟性",
252
+ "怪物",
253
+ "思量",
254
+ "怎么",
255
+ "念头",
256
+ "念叨",
257
+ "快活",
258
+ "忙活",
259
+ "志气",
260
+ "心思",
261
+ "得罪",
262
+ "张罗",
263
+ "弟兄",
264
+ "开通",
265
+ "应酬",
266
+ "庄稼",
267
+ "干事",
268
+ "帮手",
269
+ "帐篷",
270
+ "希罕",
271
+ "师父",
272
+ "师傅",
273
+ "巴结",
274
+ "巴掌",
275
+ "差事",
276
+ "工夫",
277
+ "岁数",
278
+ "屁股",
279
+ "尾巴",
280
+ "少爷",
281
+ "小气",
282
+ "小伙",
283
+ "将就",
284
+ "对头",
285
+ "对付",
286
+ "寡妇",
287
+ "家伙",
288
+ "客气",
289
+ "实在",
290
+ "官司",
291
+ "学问",
292
+ "学生",
293
+ "字号",
294
+ "嫁妆",
295
+ "媳妇",
296
+ "媒人",
297
+ "婆家",
298
+ "娘家",
299
+ "委屈",
300
+ "姑娘",
301
+ "姐夫",
302
+ "妯娌",
303
+ "妥当",
304
+ "妖精",
305
+ "奴才",
306
+ "女婿",
307
+ "头发",
308
+ "太阳",
309
+ "大爷",
310
+ "大方",
311
+ "大意",
312
+ "大夫",
313
+ "多少",
314
+ "多么",
315
+ "外甥",
316
+ "壮实",
317
+ "地道",
318
+ "地方",
319
+ "在乎",
320
+ "困难",
321
+ "嘴巴",
322
+ "嘱咐",
323
+ "嘟囔",
324
+ "嘀咕",
325
+ "喜欢",
326
+ "喇嘛",
327
+ "喇叭",
328
+ "商量",
329
+ "唾沫",
330
+ "哑巴",
331
+ "哈欠",
332
+ "哆嗦",
333
+ "咳嗽",
334
+ "和尚",
335
+ "告诉",
336
+ "告示",
337
+ "含糊",
338
+ "吓唬",
339
+ "后头",
340
+ "名字",
341
+ "名堂",
342
+ "合同",
343
+ "吆喝",
344
+ "叫唤",
345
+ "口袋",
346
+ "厚道",
347
+ "厉害",
348
+ "千斤",
349
+ "包袱",
350
+ "包涵",
351
+ "匀称",
352
+ "勤快",
353
+ "动静",
354
+ "动弹",
355
+ "功夫",
356
+ "力气",
357
+ "前头",
358
+ "刺猬",
359
+ "刺激",
360
+ "别扭",
361
+ "利落",
362
+ "利索",
363
+ "利害",
364
+ "分析",
365
+ "出息",
366
+ "凑合",
367
+ "凉快",
368
+ "冷战",
369
+ "冤枉",
370
+ "冒失",
371
+ "养活",
372
+ "关系",
373
+ "先生",
374
+ "兄弟",
375
+ "便宜",
376
+ "使唤",
377
+ "佩服",
378
+ "作坊",
379
+ "体面",
380
+ "位置",
381
+ "似的",
382
+ "伙计",
383
+ "休息",
384
+ "什么",
385
+ "人家",
386
+ "亲戚",
387
+ "亲家",
388
+ "交情",
389
+ "云彩",
390
+ "事情",
391
+ "买卖",
392
+ "主意",
393
+ "丫头",
394
+ "丧气",
395
+ "两口",
396
+ "东西",
397
+ "东家",
398
+ "世故",
399
+ "不由",
400
+ "不在",
401
+ "下水",
402
+ "下巴",
403
+ "上头",
404
+ "上司",
405
+ "丈夫",
406
+ "丈人",
407
+ "一辈",
408
+ "那个",
409
+ "菩萨",
410
+ "父亲",
411
+ "母亲",
412
+ "咕噜",
413
+ "邋遢",
414
+ "费用",
415
+ "冤家",
416
+ "甜头",
417
+ "介绍",
418
+ "荒唐",
419
+ "大人",
420
+ "泥鳅",
421
+ "幸福",
422
+ "熟悉",
423
+ "计划",
424
+ "扑腾",
425
+ "蜡烛",
426
+ "姥爷",
427
+ "照顾",
428
+ "喉咙",
429
+ "吉他",
430
+ "弄堂",
431
+ "蚂蚱",
432
+ "凤凰",
433
+ "拖沓",
434
+ "寒碜",
435
+ "糟蹋",
436
+ "倒腾",
437
+ "报复",
438
+ "逻辑",
439
+ "盘缠",
440
+ "喽啰",
441
+ "牢骚",
442
+ "咖喱",
443
+ "扫把",
444
+ "惦记",
445
+ }
446
+ self.must_not_neural_tone_words = {
447
+ "男子",
448
+ "女子",
449
+ "分子",
450
+ "原子",
451
+ "量子",
452
+ "莲子",
453
+ "石子",
454
+ "瓜子",
455
+ "电子",
456
+ "人人",
457
+ "虎虎",
458
+ }
459
+ self.punc = ":,;。?!“”‘’':,;.?!"
460
+
461
+ # the meaning of jieba pos tag: https://blog.csdn.net/weixin_44174352/article/details/113731041
462
+ # e.g.
463
+ # word: "家里"
464
+ # pos: "s"
465
+ # finals: ['ia1', 'i3']
466
+ def _neural_sandhi(self, word: str, pos: str, finals: List[str]) -> List[str]:
467
+ # reduplication words for n. and v. e.g. 奶奶, 试试, 旺旺
468
+ for j, item in enumerate(word):
469
+ if (
470
+ j - 1 >= 0
471
+ and item == word[j - 1]
472
+ and pos[0] in {"n", "v", "a"}
473
+ and word not in self.must_not_neural_tone_words
474
+ ):
475
+ finals[j] = finals[j][:-1] + "5"
476
+ ge_idx = word.find("个")
477
+ if len(word) >= 1 and word[-1] in "吧呢啊呐噻嘛吖嗨呐哦哒额滴哩哟喽啰耶喔诶":
478
+ finals[-1] = finals[-1][:-1] + "5"
479
+ elif len(word) >= 1 and word[-1] in "的地得":
480
+ finals[-1] = finals[-1][:-1] + "5"
481
+ # e.g. 走了, 看着, 去过
482
+ # elif len(word) == 1 and word in "了着过" and pos in {"ul", "uz", "ug"}:
483
+ # finals[-1] = finals[-1][:-1] + "5"
484
+ elif (
485
+ len(word) > 1
486
+ and word[-1] in "们子"
487
+ and pos in {"r", "n"}
488
+ and word not in self.must_not_neural_tone_words
489
+ ):
490
+ finals[-1] = finals[-1][:-1] + "5"
491
+ # e.g. 桌上, 地下, 家里
492
+ elif len(word) > 1 and word[-1] in "上下里" and pos in {"s", "l", "f"}:
493
+ finals[-1] = finals[-1][:-1] + "5"
494
+ # e.g. 上来, 下去
495
+ elif len(word) > 1 and word[-1] in "来去" and word[-2] in "上下进出回过起开":
496
+ finals[-1] = finals[-1][:-1] + "5"
497
+ # "个" used as a measure word
498
+ elif (
499
+ ge_idx >= 1
500
+ and (word[ge_idx - 1].isnumeric() or word[ge_idx - 1] in "几有两半多各整每做是")
501
+ ) or word == "个":
502
+ finals[ge_idx] = finals[ge_idx][:-1] + "5"
503
+ else:
504
+ if (
505
+ word in self.must_neural_tone_words
506
+ or word[-2:] in self.must_neural_tone_words
507
+ ):
508
+ finals[-1] = finals[-1][:-1] + "5"
509
+
510
+ word_list = self._split_word(word)
511
+ finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
512
+ for i, word in enumerate(word_list):
513
+ # conventional neural in Chinese
514
+ if (
515
+ word in self.must_neural_tone_words
516
+ or word[-2:] in self.must_neural_tone_words
517
+ ):
518
+ finals_list[i][-1] = finals_list[i][-1][:-1] + "5"
519
+ finals = sum(finals_list, [])
520
+ return finals
521
+
522
+ def _bu_sandhi(self, word: str, finals: List[str]) -> List[str]:
523
+ # e.g. 看不懂
524
+ if len(word) == 3 and word[1] == "不":
525
+ finals[1] = finals[1][:-1] + "5"
526
+ else:
527
+ for i, char in enumerate(word):
528
+ # "不" before tone4 should be bu2, e.g. 不怕
529
+ if char == "不" and i + 1 < len(word) and finals[i + 1][-1] == "4":
530
+ finals[i] = finals[i][:-1] + "2"
531
+ return finals
532
+
533
+ def _yi_sandhi(self, word: str, finals: List[str]) -> List[str]:
534
+ # "一" in number sequences, e.g. 一零零, 二一零
535
+ if word.find("一") != -1 and all(
536
+ [item.isnumeric() for item in word if item != "一"]
537
+ ):
538
+ return finals
539
+ # "一" between reduplication words should be yi5, e.g. 看一看
540
+ elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
541
+ finals[1] = finals[1][:-1] + "5"
542
+ # when "一" is ordinal word, it should be yi1
543
+ elif word.startswith("第一"):
544
+ finals[1] = finals[1][:-1] + "1"
545
+ else:
546
+ for i, char in enumerate(word):
547
+ if char == "一" and i + 1 < len(word):
548
+ # "一" before tone4 should be yi2, e.g. 一段
549
+ if finals[i + 1][-1] == "4":
550
+ finals[i] = finals[i][:-1] + "2"
551
+ # "一" before non-tone4 should be yi4, e.g. 一天
552
+ else:
553
+ # if "一" is followed by punctuation, it still reads as tone 1
554
+ if word[i + 1] not in self.punc:
555
+ finals[i] = finals[i][:-1] + "4"
556
+ return finals
557
+
558
+ def _split_word(self, word: str) -> List[str]:
559
+ word_list = jieba.cut_for_search(word)
560
+ word_list = sorted(word_list, key=lambda i: len(i), reverse=False)
561
+ first_subword = word_list[0]
562
+ first_begin_idx = word.find(first_subword)
563
+ if first_begin_idx == 0:
564
+ second_subword = word[len(first_subword) :]
565
+ new_word_list = [first_subword, second_subword]
566
+ else:
567
+ second_subword = word[: -len(first_subword)]
568
+ new_word_list = [second_subword, first_subword]
569
+ return new_word_list
570
+
571
+ def _three_sandhi(self, word: str, finals: List[str]) -> List[str]:
572
+ if len(word) == 2 and self._all_tone_three(finals):
573
+ finals[0] = finals[0][:-1] + "2"
574
+ elif len(word) == 3:
575
+ word_list = self._split_word(word)
576
+ if self._all_tone_three(finals):
577
+ # disyllabic + monosyllabic, e.g. 蒙古/包
578
+ if len(word_list[0]) == 2:
579
+ finals[0] = finals[0][:-1] + "2"
580
+ finals[1] = finals[1][:-1] + "2"
581
+ # monosyllabic + disyllabic, e.g. 纸/老虎
582
+ elif len(word_list[0]) == 1:
583
+ finals[1] = finals[1][:-1] + "2"
584
+ else:
585
+ finals_list = [finals[: len(word_list[0])], finals[len(word_list[0]) :]]
586
+ if len(finals_list) == 2:
587
+ for i, sub in enumerate(finals_list):
588
+ # e.g. 所有/人
589
+ if self._all_tone_three(sub) and len(sub) == 2:
590
+ finals_list[i][0] = finals_list[i][0][:-1] + "2"
591
+ # e.g. 好/喜欢
592
+ elif (
593
+ i == 1
594
+ and not self._all_tone_three(sub)
595
+ and finals_list[i][0][-1] == "3"
596
+ and finals_list[0][-1][-1] == "3"
597
+ ):
598
+ finals_list[0][-1] = finals_list[0][-1][:-1] + "2"
599
+ finals = sum(finals_list, [])
600
+ # split an idiom into two words whose length is 2
601
+ elif len(word) == 4:
602
+ finals_list = [finals[:2], finals[2:]]
603
+ finals = []
604
+ for sub in finals_list:
605
+ if self._all_tone_three(sub):
606
+ sub[0] = sub[0][:-1] + "2"
607
+ finals += sub
608
+
609
+ return finals
610
+
611
+ def _all_tone_three(self, finals: List[str]) -> bool:
612
+ return all(x[-1] == "3" for x in finals)
613
+
614
+ # merge "不" and the word behind it
615
+ # if we don't merge, "不" sometimes appears alone according to jieba, which may cause sandhi errors
616
+ def _merge_bu(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
617
+ new_seg = []
618
+ last_word = ""
619
+ for word, pos in seg:
620
+ if last_word == "不":
621
+ word = last_word + word
622
+ if word != "不":
623
+ new_seg.append((word, pos))
624
+ last_word = word[:]
625
+ if last_word == "不":
626
+ new_seg.append((last_word, "d"))
627
+ last_word = ""
628
+ return new_seg
629
+
630
+ # function 1: merge "一" and the reduplicated words on its left and right, e.g. "听","一","听" -> "听一听"
631
+ # function 2: merge single "一" and the word behind it
632
+ # if we don't merge, "一" sometimes appears alone according to jieba, which may cause sandhi errors
633
+ # e.g.
634
+ # input seg: [('听', 'v'), ('一', 'm'), ('听', 'v')]
635
+ # output seg: [['听一听', 'v']]
636
+ def _merge_yi(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
637
+ new_seg = []
638
+ # function 1
639
+ i = 0
640
+ while i < len(seg):
641
+ word, pos = seg[i]
642
+ if (
643
+ i - 1 >= 0
644
+ and word == "一"
645
+ and i + 1 < len(seg)
646
+ and seg[i - 1][0] == seg[i + 1][0]
647
+ and seg[i - 1][1] == "v"
648
+ ):
649
+ new_seg[-1][0] = new_seg[-1][0] + "一" + new_seg[-1][0]
650
+ i += 2
651
+ else:
652
+ if (
653
+ i - 2 >= 0
654
+ and seg[i - 1][0] == "一"
655
+ and seg[i - 2][0] == word
656
+ and pos == "v"
657
+ ):
658
+ i += 1  # advance the index here as well, otherwise this while-loop branch never terminates
+ continue
659
+ else:
660
+ new_seg.append([word, pos])
661
+ i += 1
662
+ seg = [i for i in new_seg if len(i) > 0]
663
+ new_seg = []
664
+ # function 2
665
+ for i, (word, pos) in enumerate(seg):
666
+ if new_seg and new_seg[-1][0] == "一":
667
+ new_seg[-1][0] = new_seg[-1][0] + word
668
+ else:
669
+ new_seg.append([word, pos])
670
+ return new_seg
671
+
672
+ # the first and the second words are all_tone_three
673
+ def _merge_continuous_three_tones(
674
+ self, seg: List[Tuple[str, str]]
675
+ ) -> List[Tuple[str, str]]:
676
+ new_seg = []
677
+ sub_finals_list = [
678
+ lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
679
+ for (word, pos) in seg
680
+ ]
681
+ assert len(sub_finals_list) == len(seg)
682
+ merge_last = [False] * len(seg)
683
+ for i, (word, pos) in enumerate(seg):
684
+ if (
685
+ i - 1 >= 0
686
+ and self._all_tone_three(sub_finals_list[i - 1])
687
+ and self._all_tone_three(sub_finals_list[i])
688
+ and not merge_last[i - 1]
689
+ ):
690
+ # if the last word is a reduplication, do not merge, because reduplications need _neural_sandhi
691
+ if (
692
+ not self._is_reduplication(seg[i - 1][0])
693
+ and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
694
+ ):
695
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
696
+ merge_last[i] = True
697
+ else:
698
+ new_seg.append([word, pos])
699
+ else:
700
+ new_seg.append([word, pos])
701
+
702
+ return new_seg
703
+
704
+ def _is_reduplication(self, word: str) -> bool:
705
+ return len(word) == 2 and word[0] == word[1]
706
+
707
+ # the last char of the first word and the first char of the second word are tone three
708
+ def _merge_continuous_three_tones_2(
709
+ self, seg: List[Tuple[str, str]]
710
+ ) -> List[Tuple[str, str]]:
711
+ new_seg = []
712
+ sub_finals_list = [
713
+ lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
714
+ for (word, pos) in seg
715
+ ]
716
+ assert len(sub_finals_list) == len(seg)
717
+ merge_last = [False] * len(seg)
718
+ for i, (word, pos) in enumerate(seg):
719
+ if (
720
+ i - 1 >= 0
721
+ and sub_finals_list[i - 1][-1][-1] == "3"
722
+ and sub_finals_list[i][0][-1] == "3"
723
+ and not merge_last[i - 1]
724
+ ):
725
+ # if the last word is a reduplication, do not merge, because reduplications need _neural_sandhi
726
+ if (
727
+ not self._is_reduplication(seg[i - 1][0])
728
+ and len(seg[i - 1][0]) + len(seg[i][0]) <= 3
729
+ ):
730
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
731
+ merge_last[i] = True
732
+ else:
733
+ new_seg.append([word, pos])
734
+ else:
735
+ new_seg.append([word, pos])
736
+ return new_seg
737
+
738
+ def _merge_er(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
739
+ new_seg = []
740
+ for i, (word, pos) in enumerate(seg):
741
+ if i - 1 >= 0 and word == "儿" and seg[i - 1][0] != "#":
742
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
743
+ else:
744
+ new_seg.append([word, pos])
745
+ return new_seg
746
+
747
+ def _merge_reduplication(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
748
+ new_seg = []
749
+ for i, (word, pos) in enumerate(seg):
750
+ if new_seg and word == new_seg[-1][0]:
751
+ new_seg[-1][0] = new_seg[-1][0] + seg[i][0]
752
+ else:
753
+ new_seg.append([word, pos])
754
+ return new_seg
755
+
756
+ def pre_merge_for_modify(self, seg: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
757
+ seg = self._merge_bu(seg)
758
+ try:
759
+ seg = self._merge_yi(seg)
760
+ except:
761
+ print("_merge_yi failed")
762
+ seg = self._merge_reduplication(seg)
763
+ seg = self._merge_continuous_three_tones(seg)
764
+ seg = self._merge_continuous_three_tones_2(seg)
765
+ seg = self._merge_er(seg)
766
+ return seg
767
+
768
+ def modified_tone(self, word: str, pos: str, finals: List[str]) -> List[str]:
769
+ finals = self._bu_sandhi(word, finals)
770
+ finals = self._yi_sandhi(word, finals)
771
+ finals = self._neural_sandhi(word, pos, finals)
772
+ finals = self._three_sandhi(word, finals)
773
+ return finals
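A hedged usage sketch of ToneSandhi together with jieba and pypinyin, mirroring how the class is fed (word, pos) pairs and FINALS_TONE3 finals elsewhere in this codebase; the printed result is indicative only.

import jieba.posseg as psg
from pypinyin import Style, lazy_pinyin

from text.tone_sandhi import ToneSandhi

sandhi = ToneSandhi()
seg = [(w, p) for w, p in psg.cut("你好世界")]
seg = sandhi.pre_merge_for_modify(seg)
for word, pos in seg:
    finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)
    finals = sandhi.modified_tone(word, pos, finals)
    print(word, finals)  # e.g. 你好 ['i2', 'ao3'] after third-tone sandhi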
tools/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ """
2
+ Utility package
3
+ """
tools/classify_language.py ADDED
@@ -0,0 +1,197 @@
1
+ import regex as re
2
+
3
+ try:
4
+ from config import config
5
+
6
+ LANGUAGE_IDENTIFICATION_LIBRARY = (
7
+ config.webui_config.language_identification_library
8
+ )
9
+ except:
10
+ LANGUAGE_IDENTIFICATION_LIBRARY = "langid"
11
+
12
+ module = LANGUAGE_IDENTIFICATION_LIBRARY.lower()
13
+
14
+ langid_languages = [
15
+ "af",
16
+ "am",
17
+ "an",
18
+ "ar",
19
+ "as",
20
+ "az",
21
+ "be",
22
+ "bg",
23
+ "bn",
24
+ "br",
25
+ "bs",
26
+ "ca",
27
+ "cs",
28
+ "cy",
29
+ "da",
30
+ "de",
31
+ "dz",
32
+ "el",
33
+ "en",
34
+ "eo",
35
+ "es",
36
+ "et",
37
+ "eu",
38
+ "fa",
39
+ "fi",
40
+ "fo",
41
+ "fr",
42
+ "ga",
43
+ "gl",
44
+ "gu",
45
+ "he",
46
+ "hi",
47
+ "hr",
48
+ "ht",
49
+ "hu",
50
+ "hy",
51
+ "id",
52
+ "is",
53
+ "it",
54
+ "ja",
55
+ "jv",
56
+ "ka",
57
+ "kk",
58
+ "km",
59
+ "kn",
60
+ "ko",
61
+ "ku",
62
+ "ky",
63
+ "la",
64
+ "lb",
65
+ "lo",
66
+ "lt",
67
+ "lv",
68
+ "mg",
69
+ "mk",
70
+ "ml",
71
+ "mn",
72
+ "mr",
73
+ "ms",
74
+ "mt",
75
+ "nb",
76
+ "ne",
77
+ "nl",
78
+ "nn",
79
+ "no",
80
+ "oc",
81
+ "or",
82
+ "pa",
83
+ "pl",
84
+ "ps",
85
+ "pt",
86
+ "qu",
87
+ "ro",
88
+ "ru",
89
+ "rw",
90
+ "se",
91
+ "si",
92
+ "sk",
93
+ "sl",
94
+ "sq",
95
+ "sr",
96
+ "sv",
97
+ "sw",
98
+ "ta",
99
+ "te",
100
+ "th",
101
+ "tl",
102
+ "tr",
103
+ "ug",
104
+ "uk",
105
+ "ur",
106
+ "vi",
107
+ "vo",
108
+ "wa",
109
+ "xh",
110
+ "zh",
111
+ "zu",
112
+ ]
113
+
114
+
115
+ def classify_language(text: str, target_languages: list = None) -> str:
116
+ if module == "fastlid" or module == "fasttext":
117
+ from fastlid import fastlid, supported_langs
118
+
119
+ classifier = fastlid
120
+ if target_languages != None:
121
+ target_languages = [
122
+ lang for lang in target_languages if lang in supported_langs
123
+ ]
124
+ fastlid.set_languages = target_languages
125
+ elif module == "langid":
126
+ import langid
127
+
128
+ classifier = langid.classify
129
+ if target_languages != None:
130
+ target_languages = [
131
+ lang for lang in target_languages if lang in langid_languages
132
+ ]
133
+ langid.set_languages(target_languages)
134
+ else:
135
+ raise ValueError(f"Wrong module {module}")
136
+
137
+ lang = classifier(text)[0]
138
+
139
+ return lang
140
+
141
+
142
+ def classify_zh_ja(text: str) -> str:
143
+ for idx, char in enumerate(text):
144
+ unicode_val = ord(char)
145
+
146
+ # detect Japanese characters (hiragana / katakana)
147
+ if 0x3040 <= unicode_val <= 0x309F or 0x30A0 <= unicode_val <= 0x30FF:
148
+ return "ja"
149
+
150
+ # detect Han (CJK) characters
151
+ if 0x4E00 <= unicode_val <= 0x9FFF:
152
+ # check the surrounding characters
153
+ next_char = text[idx + 1] if idx + 1 < len(text) else None
154
+
155
+ if next_char and (
156
+ 0x3040 <= ord(next_char) <= 0x309F or 0x30A0 <= ord(next_char) <= 0x30FF
157
+ ):
158
+ return "ja"
159
+
160
+ return "zh"
161
+
162
+
163
+ def split_alpha_nonalpha(text, mode=1):
164
+ if mode == 1:
165
+ pattern = r"(?<=[\u4e00-\u9fff\u3040-\u30FF\d\s])(?=[\p{Latin}])|(?<=[\p{Latin}\s])(?=[\u4e00-\u9fff\u3040-\u30FF\d])"
166
+ elif mode == 2:
167
+ pattern = r"(?<=[\u4e00-\u9fff\u3040-\u30FF\s])(?=[\p{Latin}\d])|(?<=[\p{Latin}\d\s])(?=[\u4e00-\u9fff\u3040-\u30FF])"
168
+ else:
169
+ raise ValueError("Invalid mode. Supported modes are 1 and 2.")
170
+
171
+ return re.split(pattern, text)
172
+
173
+
174
+ if __name__ == "__main__":
175
+ text = "这是一个测试文本"
176
+ print(classify_language(text))
177
+ print(classify_zh_ja(text)) # "zh"
178
+
179
+ text = "これはテストテキストです"
180
+ print(classify_language(text))
181
+ print(classify_zh_ja(text)) # "ja"
182
+
183
+ text = "vits和Bert-VITS2是tts模型。花费3days.花费3天。Take 3 days"
184
+
185
+ print(split_alpha_nonalpha(text, mode=1))
186
+ # output: ['vits', '和', 'Bert-VITS', '2是', 'tts', '模型。花费3', 'days.花费3天。Take 3 days']
187
+
188
+ print(split_alpha_nonalpha(text, mode=2))
189
+ # output: ['vits', '和', 'Bert-VITS2', '是', 'tts', '模型。花费', '3days.花费', '3', '天。Take 3 days']
190
+
191
+ text = "vits 和 Bert-VITS2 是 tts 模型。花费3days.花费3天。Take 3 days"
192
+ print(split_alpha_nonalpha(text, mode=1))
193
+ # output: ['vits ', '和 ', 'Bert-VITS', '2 ', '是 ', 'tts ', '模型。花费3', 'days.花费3天。Take ', '3 ', 'days']
194
+
195
+ text = "vits 和 Bert-VITS2 是 tts 模型。花费3days.花费3天。Take 3 days"
196
+ print(split_alpha_nonalpha(text, mode=2))
197
+ # output: ['vits ', '和 ', 'Bert-VITS2 ', '是 ', 'tts ', '模型。花费', '3days.花费', '3', '天。Take ', '3 ', 'days']
tools/sentence.py ADDED
@@ -0,0 +1,173 @@
1
+ import logging
2
+
3
+ import regex as re
4
+
5
+ from tools.classify_language import classify_language, split_alpha_nonalpha
6
+
7
+
8
+ def check_is_none(item) -> bool:
9
+ """none -> True, not none -> False"""
10
+ return (
11
+ item is None
12
+ or (isinstance(item, str) and str(item).isspace())
13
+ or str(item) == ""
14
+ )
15
+
16
+
17
+ def markup_language(text: str, target_languages: list = None) -> str:
18
+ pattern = (
19
+ r"[\!\"\#\$\%\&\'\(\)\*\+\,\-\.\/\:\;\<\>\=\?\@\[\]\{\}\\\\\^\_\`"
20
+ r"\!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」"
21
+ r"『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘\'\‛\“\”\„\‟…‧﹏.]+"
22
+ )
23
+ sentences = re.split(pattern, text)
24
+
25
+ pre_lang = ""
26
+ p = 0
27
+
28
+ if target_languages is not None:
29
+ sorted_target_languages = sorted(target_languages)
30
+ if sorted_target_languages in [["en", "zh"], ["en", "ja"], ["en", "ja", "zh"]]:
31
+ new_sentences = []
32
+ for sentence in sentences:
33
+ new_sentences.extend(split_alpha_nonalpha(sentence))
34
+ sentences = new_sentences
35
+
36
+ for sentence in sentences:
37
+ if check_is_none(sentence):
38
+ continue
39
+
40
+ lang = classify_language(sentence, target_languages)
41
+
42
+ if pre_lang == "":
43
+ text = text[:p] + text[p:].replace(
44
+ sentence, f"[{lang.upper()}]{sentence}", 1
45
+ )
46
+ p += len(f"[{lang.upper()}]")
47
+ elif pre_lang != lang:
48
+ text = text[:p] + text[p:].replace(
49
+ sentence, f"[{pre_lang.upper()}][{lang.upper()}]{sentence}", 1
50
+ )
51
+ p += len(f"[{pre_lang.upper()}][{lang.upper()}]")
52
+ pre_lang = lang
53
+ p += text[p:].index(sentence) + len(sentence)
54
+ text += f"[{pre_lang.upper()}]"
55
+
56
+ return text
57
+
58
+
59
+ def split_by_language(text: str, target_languages: list = None) -> list:
60
+ pattern = (
61
+ r"[\!\"\#\$\%\&\'\(\)\*\+\,\-\.\/\:\;\<\>\=\?\@\[\]\{\}\\\\\^\_\`"
62
+ r"\!?\。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」"
63
+ r"『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘\'\‛\“\”\„\‟…‧﹏.]+"
64
+ )
65
+ sentences = re.split(pattern, text)
66
+
67
+ pre_lang = ""
68
+ start = 0
69
+ end = 0
70
+ sentences_list = []
71
+
72
+ if target_languages is not None:
73
+ sorted_target_languages = sorted(target_languages)
74
+ if sorted_target_languages in [["en", "zh"], ["en", "ja"], ["en", "ja", "zh"]]:
75
+ new_sentences = []
76
+ for sentence in sentences:
77
+ new_sentences.extend(split_alpha_nonalpha(sentence))
78
+ sentences = new_sentences
79
+
80
+ for sentence in sentences:
81
+ if check_is_none(sentence):
82
+ continue
83
+
84
+ lang = classify_language(sentence, target_languages)
85
+
86
+ end += text[end:].index(sentence)
87
+ if pre_lang != "" and pre_lang != lang:
88
+ sentences_list.append((text[start:end], pre_lang))
89
+ start = end
90
+ end += len(sentence)
91
+ pre_lang = lang
92
+ sentences_list.append((text[start:], pre_lang))
93
+
94
+ return sentences_list
95
+
96
+
97
+ def sentence_split(text: str, max: int) -> list:
98
+ pattern = r"[!(),—+\-.:;??。,、;:]+"
99
+ sentences = re.split(pattern, text)
100
+ discarded_chars = re.findall(pattern, text)
101
+
102
+ sentences_list, count, p = [], 0, 0
103
+
104
+ # iterate over the delimiters that the text was split on
105
+ for i, puncs in enumerate(discarded_chars):
106
+ count += len(sentences[i]) + len(puncs)
107
+ if count >= max:
108
+ sentences_list.append(text[p : p + count].strip())
109
+ p += count
110
+ count = 0
111
+
112
+ # append the remaining text at the end
113
+ if p < len(text):
114
+ sentences_list.append(text[p:])
115
+
116
+ return sentences_list
117
+
118
+
119
+ def sentence_split_and_markup(text, max=50, lang="auto", speaker_lang=None):
120
+ # if this speaker only supports one language
121
+ if speaker_lang is not None and len(speaker_lang) == 1:
122
+ if lang.upper() not in ["AUTO", "MIX"] and lang.lower() != speaker_lang[0]:
123
+ logging.debug(
124
+ f'lang "{lang}" is not in speaker_lang {speaker_lang},automatically set lang={speaker_lang[0]}'
125
+ )
126
+ lang = speaker_lang[0]
127
+
128
+ sentences_list = []
129
+ if lang.upper() != "MIX":
130
+ if max <= 0:
131
+ sentences_list.append(
132
+ markup_language(text, speaker_lang)
133
+ if lang.upper() == "AUTO"
134
+ else f"[{lang.upper()}]{text}[{lang.upper()}]"
135
+ )
136
+ else:
137
+ for i in sentence_split(text, max):
138
+ if check_is_none(i):
139
+ continue
140
+ sentences_list.append(
141
+ markup_language(i, speaker_lang)
142
+ if lang.upper() == "AUTO"
143
+ else f"[{lang.upper()}]{i}[{lang.upper()}]"
144
+ )
145
+ else:
146
+ sentences_list.append(text)
147
+
148
+ for i in sentences_list:
149
+ logging.debug(i)
150
+
151
+ return sentences_list
152
+
153
+
154
+ if __name__ == "__main__":
155
+ text = "这几天心里颇不宁静。今晚在院子里坐着乘凉,忽然想起日日走过的荷塘,在这满月的光里,总该另有一番样子吧。月亮渐渐地升高了,墙外马路上孩子们的欢笑,已经听不见了;妻在屋里拍着闰儿,迷迷糊糊地哼着眠歌。我悄悄地披了大衫,带上门出去。"
156
+ print(markup_language(text, target_languages=None))
157
+ print(sentence_split(text, max=50))
158
+ print(sentence_split_and_markup(text, max=50, lang="auto", speaker_lang=None))
159
+
160
+ text = "你好,这是一段用来测试自动标注的文本。こんにちは,これは自動ラベリングのテスト用テキストです.Hello, this is a piece of text to test autotagging.你好!今天我们要介绍VITS项目,其重点是使用了GAN Duration predictor和transformer flow,并且接入了Bert模型来提升韵律。Bert embedding会在稍后介绍。"
161
+ print(split_by_language(text, ["zh", "ja", "en"]))
162
+
163
+ text = "vits和Bert-VITS2是tts模型。花费3days.花费3天。Take 3 days"
164
+
165
+ print(split_by_language(text, ["zh", "ja", "en"]))
166
+ # output: [('vits', 'en'), ('和', 'ja'), ('Bert-VITS', 'en'), ('2是', 'zh'), ('tts', 'en'), ('模型。花费3', 'zh'), ('days.', 'en'), ('花费3天。', 'zh'), ('Take 3 days', 'en')]
167
+
168
+ print(split_by_language(text, ["zh", "en"]))
169
+ # output: [('vits', 'en'), ('和', 'zh'), ('Bert-VITS', 'en'), ('2是', 'zh'), ('tts', 'en'), ('模型。花费3', 'zh'), ('days.', 'en'), ('花费3天。', 'zh'), ('Take 3 days', 'en')]
170
+
171
+ text = "vits 和 Bert-VITS2 是 tts 模型。花费 3 days. 花费 3天。Take 3 days"
172
+ print(split_by_language(text, ["zh", "en"]))
173
+ # output: [('vits ', 'en'), ('和 ', 'zh'), ('Bert-VITS2 ', 'en'), ('是 ', 'zh'), ('tts ', 'en'), ('模型。花费 ', 'zh'), ('3 days. ', 'en'), ('花费 3天。', 'zh'), ('Take 3 days', 'en')]
tools/translate.py ADDED
@@ -0,0 +1,61 @@
1
+ """
2
+ Translation API
3
+ """
4
+ from config import config
5
+
6
+ import random
7
+ import hashlib
8
+ import requests
9
+
10
+
11
+ def translate(Sentence: str, to_Language: str = "jp", from_Language: str = ""):
12
+ """
13
+ :param Sentence: text to translate
14
+ :param from_Language: language of the text to translate
15
+ :param to_Language: target language
16
+ :return: translated text; returns None on error
17
+
18
+ Common language codes: Chinese zh, English en, Japanese jp
19
+ """
20
+ appid = config.translate_config.app_key
21
+ key = config.translate_config.secret_key
22
+ if appid == "" or key == "":
23
+ return "请开发者在config.yml中配置app_key与secret_key"
24
+ url = "https://fanyi-api.baidu.com/api/trans/vip/translate"
25
+ texts = Sentence.splitlines()
26
+ outTexts = []
27
+ for t in texts:
28
+ if t != "":
29
+ # Signature calculation; see https://api.fanyi.baidu.com/product/113
30
+ salt = str(random.randint(1, 100000))
31
+ signString = appid + t + salt + key
32
+ hs = hashlib.md5()
33
+ hs.update(signString.encode("utf-8"))
34
+ signString = hs.hexdigest()
35
+ if from_Language == "":
36
+ from_Language = "auto"
37
+ headers = {"Content-Type": "application/x-www-form-urlencoded"}
38
+ payload = {
39
+ "q": t,
40
+ "from": from_Language,
41
+ "to": to_Language,
42
+ "appid": appid,
43
+ "salt": salt,
44
+ "sign": signString,
45
+ }
46
+ # Send the request
47
+ try:
48
+ response = requests.post(
49
+ url=url, data=payload, headers=headers, timeout=3
50
+ )
51
+ response = response.json()
52
+ if "trans_result" in response.keys():
53
+ result = response["trans_result"][0]
54
+ if "dst" in result.keys():
55
+ dst = result["dst"]
56
+ outTexts.append(dst)
57
+ except Exception:
58
+ return Sentence
59
+ else:
60
+ outTexts.append(t)
61
+ return "\n".join(outTexts)
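A hedged usage sketch; it assumes config.yml provides valid Baidu Translate credentials under translate_config (app_key / secret_key), and note that in this code a request failure falls back to returning the input text unchanged.

from tools.translate import translate

print(translate("こんにちは、テストです。", to_Language="zh", from_Language="jp"))
print(translate("Hello world", to_Language="jp"))  # from_Language="" lets the API auto-detect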
utils.py CHANGED
@@ -13,7 +13,7 @@ from safetensors import safe_open
13
  from safetensors.torch import save_file
14
  from scipy.io.wavfile import read
15
 
16
- from tools.log import logger
17
 
18
  MATPLOTLIB_FLAG = False
19
 
@@ -189,10 +189,11 @@ def summarize(
189
 
190
 
191
  def is_resuming(dir_path):
 
192
  g_list = glob.glob(os.path.join(dir_path, "G_*.pth"))
193
- d_list = glob.glob(os.path.join(dir_path, "D_*.pth"))
194
- dur_list = glob.glob(os.path.join(dir_path, "DUR_*.pth"))
195
- return len(g_list) > 0 and len(d_list) > 0 and len(dur_list) > 0
196
 
197
 
198
  def latest_checkpoint_path(dir_path, regex="G_*.pth"):
@@ -348,7 +349,7 @@ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_tim
348
  ]
349
 
350
  def del_info(fn):
351
- return logger.info(f".. Free up space by deleting ckpt {fn}")
352
 
353
  def del_routine(x):
354
  return [os.remove(x), del_info(x)]
 
13
  from safetensors.torch import save_file
14
  from scipy.io.wavfile import read
15
 
16
+ from common.log import logger
17
 
18
  MATPLOTLIB_FLAG = False
19
 
 
189
 
190
 
191
  def is_resuming(dir_path):
192
+ # The JP-Extra version has no DUR (and may have WD instead), so judge resumability from G checkpoints only
193
  g_list = glob.glob(os.path.join(dir_path, "G_*.pth"))
194
+ # d_list = glob.glob(os.path.join(dir_path, "D_*.pth"))
195
+ # dur_list = glob.glob(os.path.join(dir_path, "DUR_*.pth"))
196
+ return len(g_list) > 0
197
 
198
 
199
  def latest_checkpoint_path(dir_path, regex="G_*.pth"):
 
349
  ]
350
 
351
  def del_info(fn):
352
+ return logger.info(f"Free up space by deleting ckpt {fn}")
353
 
354
  def del_routine(x):
355
  return [os.remove(x), del_info(x)]