Update app.py
app.py (CHANGED)
```diff
@@ -7,6 +7,7 @@ import uuid
 from io import StringIO
 
 import gradio as gr
+import spaces
 import torch
 import torchaudio
 from huggingface_hub import HfApi, hf_hub_download, snapshot_download
```
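The new `spaces` import exists to support the `@spaces.GPU` decorator added to `predict` further down. On Hugging Face ZeroGPU Spaces, the decorator requests a GPU for the duration of each decorated call. A minimal usage sketch (the `synthesize` function is hypothetical, not from this commit):

```python
import spaces  # available inside Hugging Face Spaces

@spaces.GPU  # on ZeroGPU hardware, allocates a GPU for each call
def synthesize(text: str):
    ...  # GPU-bound work (model inference) goes here
```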
```diff
@@ -14,25 +15,25 @@ from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.xtts import Xtts
 from vinorm import TTSnorm
 
-#
+# download for mecab
 os.system("python -m unidic download")
 
-# Hugging Face token and API setup
 HF_TOKEN = os.environ.get("HF_TOKEN")
 api = HfApi(token=HF_TOKEN)
 
-#
+# This will trigger downloading the model if it is not present
+print("Downloading viXTTS if not already downloaded")
 checkpoint_dir = "model/"
+repo_id = "capleaf/viXTTS"
+use_deepspeed = False
+
 os.makedirs(checkpoint_dir, exist_ok=True)
 
-# Required files for the model
 required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
-
-# Download model and configurations if not present
 files_in_dir = os.listdir(checkpoint_dir)
 if not all(file in files_in_dir for file in required_files):
     snapshot_download(
-        repo_id=
+        repo_id=repo_id,
         repo_type="model",
         local_dir=checkpoint_dir,
     )
```
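The gate above only checks that the required file names appear in the directory listing, so a zero-byte or partially downloaded file would still pass. A stricter check might look like this (the `model_files_present` helper is hypothetical, not part of this commit):

```python
from pathlib import Path

def model_files_present(checkpoint_dir: str, required_files: list[str]) -> bool:
    """True only if every required file exists and is non-empty."""
    return all(
        (p := Path(checkpoint_dir, f)).is_file() and p.stat().st_size > 0
        for f in required_files
    )
```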
```diff
@@ -42,23 +43,21 @@ if not all(file in files_in_dir for file in required_files):
         local_dir=checkpoint_dir,
     )
 
-# Initialize XTTS model from configuration
 xtts_config = os.path.join(checkpoint_dir, "config.json")
 config = XttsConfig()
 config.load_json(xtts_config)
 MODEL = Xtts.init_from_config(config)
 MODEL.load_checkpoint(
-    config, checkpoint_dir=checkpoint_dir, use_deepspeed=
+    config, checkpoint_dir=checkpoint_dir, use_deepspeed=use_deepspeed
 )
 if torch.cuda.is_available():
     MODEL.cuda()
 
-# Supported languages for TTS
 supported_languages = config.languages
-if "vi"
+if "vi" not in supported_languages:
     supported_languages.append("vi")
 
 
 def normalize_vietnamese_text(text):
     text = (
         TTSnorm(text, unknown=False, lower=False, rule=True)
```
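`use_deepspeed` is introduced as a variable but pinned to `False`. If the intent is ever to enable DeepSpeed opportunistically, one possible derivation (a sketch, not from this commit) is:

```python
import torch

# enable DeepSpeed inference only when the package imports and a GPU is present
try:
    import deepspeed  # noqa: F401
    use_deepspeed = torch.cuda.is_available()
except ImportError:
    use_deepspeed = False
```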
```diff
@@ -74,8 +73,9 @@ def normalize_vietnamese_text(text):
     )
     return text
 
 
 def calculate_keep_len(text, lang):
+    """Simple hack for short sentences"""
     if lang in ["ja", "zh-cn"]:
         return -1
```
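The middle of `calculate_keep_len` falls between hunks and is not shown; the visible lines truncate the generated waveform (24 kHz samples) for short prompts. A sketch of the whole heuristic, with the unshown middle marked as an assumption:

```python
def calculate_keep_len(text, lang):
    """Simple hack for short sentences: keep only the expected number of samples."""
    if lang in ["ja", "zh-cn"]:
        return -1  # keep everything for languages without space-separated words

    # assumption: the elided middle counts words and sentence punctuation roughly like this
    word_count = len(text.split())
    num_punct = text.count(".") + text.count("!") + text.count("?") + text.count(",")

    if word_count < 5:  # assumption: the threshold for "short"
        return 13000 * word_count + 2000 * num_punct  # visible in the diff
    return -1  # visible in the diff: otherwise keep everything
```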
```diff
@@ -88,30 +88,65 @@ def calculate_keep_len(text, lang):
         return 13000 * word_count + 2000 * num_punct
     return -1
 
-
-def predict(prompt, language, audio_file_pth, normalize_text=True):
+
+@spaces.GPU
+def predict(
+    prompt,
+    language,
+    audio_file_pth,
+    normalize_text=True,
+):
     if language not in supported_languages:
+        metrics_text = gr.Warning(
+            f"The language you entered ({language}) is not in our supported languages, please choose one from the dropdown"
+        )
+
+        return (None, metrics_text)
 
     speaker_wav = audio_file_pth
 
+    if len(prompt) < 2:
+        metrics_text = gr.Warning("Please give a longer prompt text")
+        return (None, metrics_text)
+
+    if len(prompt) > 250:
+        metrics_text = gr.Warning(
+            str(len(prompt))
+            + " characters.\n"
+            + "Your prompt is too long, please keep it under 250 characters\n"
+            + "Văn bản quá dài, vui lòng giữ dưới 250 ký tự."
+        )
+        return (None, metrics_text)
+
     try:
+        metrics_text = ""
+        t_latent = time.time()
+
+        try:
+            (
+                gpt_cond_latent,
+                speaker_embedding,
+            ) = MODEL.get_conditioning_latents(
+                audio_path=speaker_wav,
+                gpt_cond_len=30,
+                gpt_cond_chunk_len=4,
+                max_ref_length=60,
+            )
+        except Exception as e:
+            print("Speaker encoding error", str(e))
+            metrics_text = gr.Warning(
+                "It appears something is wrong with the reference audio, did you unmute your microphone?"
+            )
+            return (None, metrics_text)
+
+        # pad sentence-ending punctuation so XTTS pauses at sentence boundaries
+        prompt = re.sub(r"([^\x00-\x7F]|\w)(\.|\。|\?)", r"\1 \2\2", prompt)
 
-        # Normalize Vietnamese text if specified
         if normalize_text and language == "vi":
             prompt = normalize_vietnamese_text(prompt)
 
-            audio_path=speaker_wav,
-            gpt_cond_len=30,
-            gpt_cond_chunk_len=4,
-            max_ref_length=60,
-        )
-
-        # Perform inference to generate audio
+        print("I: Generating new audio...")
+        t0 = time.time()
         out = MODEL.inference(
             prompt,
             language,
```
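What the added `re.sub` does, on a concrete input: every word character (or non-ASCII letter) followed by `.`, `。`, or `?` gets a space and a doubled mark, which in practice nudges XTTS to pause at sentence ends. An illustration with the same pattern:

```python
import re

pattern = r"([^\x00-\x7F]|\w)(\.|\。|\?)"
print(re.sub(pattern, r"\1 \2\2", "Xin chào. Bạn khỏe không?"))
# -> Xin chào .. Bạn khỏe không ??
```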
```diff
@@ -121,50 +156,166 @@ def predict(prompt, language, audio_file_pth, normalize_text=True):
             temperature=0.75,
             enable_text_splitting=True,
         )
-
-        # Calculate inference time and real-time factor
         inference_time = time.time() - t0
+        print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
+        metrics_text += (
+            f"Time to generate audio: {round(inference_time*1000)} milliseconds\n"
+        )
         real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
+        print(f"Real-time factor (RTF): {real_time_factor}")
+        metrics_text += f"Real-time factor (RTF): {real_time_factor:.2f}\n"
 
-        #
+        # Temporary hack for short sentences
         keep_len = calculate_keep_len(prompt, language)
         out["wav"] = out["wav"][:keep_len]
 
-        # Save generated audio
         torchaudio.save("output.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
 
     except RuntimeError as e:
+        if "device-side assert" in str(e):
+            # cannot recover from a CUDA device-side assert, the process needs a restart
+            print(
+                f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}",
+                flush=True,
+            )
+            gr.Warning("Unhandled exception encountered, please retry in a minute")
+            print("CUDA device-side assert encountered, restart needed")
+
+            error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
+            error_data = [
+                error_time,
+                prompt,
+                language,
+                audio_file_pth,
+            ]
+            error_data = [str(d) if not isinstance(d, str) else d for d in error_data]
+            print(error_data)
+            print(speaker_wav)
+            write_io = StringIO()
+            csv.writer(write_io).writerows([error_data])
+            csv_upload = write_io.getvalue().encode()
+
+            filename = error_time + "_" + str(uuid.uuid4()) + ".csv"
+            print("Writing error csv")
+            error_api = HfApi()
+            error_api.upload_file(
+                path_or_fileobj=csv_upload,
+                path_in_repo=filename,
+                repo_id="coqui/xtts-flagged-dataset",
+                repo_type="dataset",
+            )
+
+            # upload the reference audio alongside the CSV
+            print("Writing error reference audio")
+            speaker_filename = error_time + "_reference_" + str(uuid.uuid4()) + ".wav"
+            error_api = HfApi()
+            error_api.upload_file(
+                path_or_fileobj=speaker_wav,
+                path_in_repo=speaker_filename,
+                repo_id="coqui/xtts-flagged-dataset",
+                repo_type="dataset",
+            )
+
+            # HF Space specific: this error is unrecoverable, the Space needs a restart
+            # NOTE: repo_id is the model repo here ("capleaf/viXTTS"); the Space id may differ
+            space = api.get_space_runtime(repo_id=repo_id)
+            if space.stage != "BUILDING":
+                api.restart_space(repo_id=repo_id)
+            else:
+                print("TRIED TO RESTART but space is building")
+
+        else:
+            if "Failed to decode" in str(e):
+                print("Speaker encoding error", str(e))
+                metrics_text = gr.Warning(
+                    "It appears something is wrong with the reference audio, did you unmute your microphone?"
+                )
+            else:
+                print("RuntimeError: non device-side assert error:", str(e))
+                metrics_text = gr.Warning(
+                    "Something unexpected happened, please retry."
+                )
+        return (None, metrics_text)
+    return ("output.wav", metrics_text)
+
+
+with gr.Blocks(analytics_enabled=False) as demo:
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown(
+                """
+                # viXTTS Demo ✨
+                - Github: https://github.com/thinhlpg/vixtts-demo/
+                - viVoice: https://github.com/thinhlpg/viVoice
+                """
+            )
+        with gr.Column():
+            # placeholder to align the image
+            pass
+
+    with gr.Row():
+        with gr.Column():
+            input_text_gr = gr.Textbox(
+                label="Text Prompt (Văn bản cần đọc)",
+                info="Mỗi câu nên từ 10 từ trở lên. Tối đa 250 ký tự (khoảng 2 - 3 câu).",
+                value="Xin chào, tôi là một mô hình chuyển đổi văn bản thành giọng nói tiếng Việt.",
+            )
+            language_gr = gr.Dropdown(
+                label="Language (Ngôn ngữ)",
+                choices=[
+                    "vi",
+                    "en",
+                    "es",
+                    "fr",
+                    "de",
+                    "it",
+                    "pt",
+                    "pl",
+                    "tr",
+                    "ru",
+                    "nl",
+                    "cs",
+                    "ar",
+                    "zh-cn",
+                    "ja",
+                    "ko",
+                    "hu",
+                    "hi",
+                ],
+                max_choices=1,
+                value="vi",
+            )
+            normalize_text = gr.Checkbox(
+                label="Chuẩn hóa văn bản tiếng Việt",
+                info="Normalize Vietnamese text",
+                value=True,
+            )
+            ref_gr = gr.Audio(
+                label="Reference Audio (Giọng mẫu)",
+                type="filepath",
+                value="model/samples/nu-luu-loat.wav",
+            )
+            tts_button = gr.Button(
+                "Đọc 🗣️🔥",
+                elem_id="send-btn",
+                visible=True,
+                variant="primary",
+            )
+
+        with gr.Column():
+            audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
+            out_text_gr = gr.Text(label="Metrics")
+
+    tts_button.click(
+        predict,
+        [
+            input_text_gr,
+            language_gr,
+            ref_gr,
+            normalize_text,
+        ],
+        outputs=[audio_gr, out_text_gr],
+        api_name="predict",
+    )
 
-demo.
+demo.queue()
+demo.launch(debug=True, show_api=True, share=True)
```
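The RTF line divides wall-clock synthesis time by the duration of the generated audio, since `out["wav"].shape[-1]` is a sample count at XTTS's 24 kHz output rate. A worked example with made-up numbers:

```python
num_samples = 120_000                    # e.g. out["wav"].shape[-1]
audio_seconds = num_samples / 24_000     # 5.0 seconds of generated audio
inference_seconds = 2.5                  # e.g. time.time() - t0
rtf = inference_seconds / audio_seconds  # 0.5; below 1.0 means faster than real time
```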