# xtts_awesome / webui.py — uploaded via huggingface_hub by awesome-paulw (commit 1207342, verified)
# main.py
import gradio as gr
from gradio import State
from gradio_utils import *
from pathlib import Path
import argparse
from tools.i18n.i18n import I18nAuto
from config import is_share, webui_port_main
from functions.core_functions import convert_voice, process_srt_and_generate_audio, load_model, run_tts
from functions.slice_utils import open_slice, close_slice
from functions.logging_utils import remove_log_file, read_logs
from multiprocessing import cpu_count
import os
from subprocess import Popen
def launch():
    """Build the XTTS fine-tuning Gradio web UI and start the server.

    The UI is organised as six tabs:
      0 - slice long audio files into short training clips,
      1 - transcribe clips with Whisper to build a dataset,
      2 - fine-tune the XTTS encoder and optimize the checkpoint,
      3 - load the fine-tuned model and run inference (free text or SRT),
      4 - OpenVoice-based voice conversion,
      5 - display accumulated log output.

    CLI flags (``--port``, ``--out_path``, ``--num_epochs``, ``--batch_size``,
    ``--grad_acumm``, ``--max_audio_length``) seed the defaults of the
    corresponding widgets.  The callbacks (``preprocess_dataset``,
    ``load_params``, ``train_model``, ``optimize_model``, ``load_params_tts``)
    come from the ``gradio_utils`` star-import — TODO confirm.
    """
    parser = argparse.ArgumentParser(
        description="""XTTS fine-tuning demo\n\n"""
        """
        Example runs:
        python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port
        """,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--port",
        type=int,
        help="Port to run the gradio demo. Default: 5003",
        default=5003,
    )
    parser.add_argument(
        "--out_path",
        type=str,
        help="Output path (where data and checkpoints will be saved) Default: output/",
        default=str(Path.cwd() / "finetune_models"),
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        help="Number of epochs to train. Default: 6",
        default=6,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help="Batch size. Default: 2",
        default=2,
    )
    parser.add_argument(
        "--grad_acumm",
        type=int,
        help="Grad accumulation steps. Default: 1",
        default=1,
    )
    parser.add_argument(
        "--max_audio_length",
        type=int,
        help="Max permitted audio size in seconds. Default: 11",
        default=11,
    )
    args = parser.parse_args()

    # i18n translates the Chinese label keys used in Tab 0 below.
    i18n = I18nAuto()
    demo = gr.Blocks()
    with demo:
        # ---- Tab 0: slice long recordings into short clips --------------
        with gr.Tab("0 - Audio Slicing"):
            gr.Markdown(value=i18n("语音切分工具"))
            with gr.Row():
                slice_inp_path = gr.Textbox(label=i18n("音频自动切分输入路径,可文件可文件夹"), value="")
                slice_opt_root = gr.Textbox(label=i18n("切分后的子音频的输出根目录"), value="output/slicer_opt")
                # Slicer tuning knobs are passed as strings; the callback
                # presumably converts them — TODO confirm in open_slice.
                threshold = gr.Textbox(label=i18n("threshold:音量小于这个值视作静音的备选切割点"), value="-34")
                min_length = gr.Textbox(label=i18n("min_length:每段最小多长,如果第一段太短一直和后面段连起来直到超过这个值"), value="4000")
                min_interval = gr.Textbox(label=i18n("min_interval:最短切割间隔"), value="300")
                hop_size = gr.Textbox(label=i18n("hop_size:怎么算音量曲线,越小精度越大计算量越高(不是精度越大效果越好)"), value="10")
                max_sil_kept = gr.Textbox(label=i18n("max_sil_kept:切完后静音最多留多长"), value="500")
            with gr.Row():
                # Start/stop buttons toggle each other's visibility via the
                # callback outputs below.
                open_slicer_button = gr.Button(i18n("开启语音切割"), variant="primary", visible=True)
                close_slicer_button = gr.Button(i18n("终止语音切割"), variant="primary", visible=False)
                _max = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("max:归一化后最大值多少"), value=0.9, interactive=True)
                alpha = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("alpha_mix:混多少比例归一化后音频进来"), value=0.25, interactive=True)
                n_process = gr.Slider(minimum=1, maximum=cpu_count(), step=1, label=i18n("切割使用的进程数"), value=4, interactive=True)
                slicer_info = gr.Textbox(label=i18n("语音切割进程输出信息"))
            open_slicer_button.click(open_slice, [slice_inp_path, slice_opt_root, threshold, min_length, min_interval, hop_size, max_sil_kept, _max, alpha, n_process], [slicer_info, open_slicer_button, close_slicer_button])
            close_slicer_button.click(close_slice, [], [slicer_info, open_slicer_button, close_slicer_button])

        # ---- Tab 1: Whisper transcription -> training dataset -----------
        with gr.Tab("1 - Data processing"):
            out_path = gr.Textbox(label="Output path (where data and checkpoints will be saved):", value=args.out_path)
            upload_file = gr.File(file_count="multiple", label="Select here the audio files that you want to use for XTTS trainining (Supported formats: wav, mp3, and flac)")
            folder_path = gr.Textbox(label="Or input the path of a folder containing audio files")
            whisper_model = gr.Dropdown(label="Whisper Model", value="large-v3", choices=["large-v3", "large-v2", "large", "medium", "small"])
            lang = gr.Dropdown(label="Dataset Language", value="en", choices=["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh", "hu", "ko", "ja"])
            progress_data = gr.Label(label="Progress:")
            #train_csv = gr.Textbox(visible=False)
            #eval_csv = gr.Textbox(visible=False)
            prompt_compute_btn = gr.Button(value="Step 1 - Create dataset")
            # Step 1 stores the generated CSV paths in session state; Step 2
            # reads them back from these same State objects.
            train_csv_state = State()
            eval_csv_state = State()
            prompt_compute_btn.click(preprocess_dataset, inputs=[upload_file, folder_path, lang, whisper_model, out_path, train_csv_state, eval_csv_state], outputs=[progress_data, train_csv_state, eval_csv_state])
            #prompt_compute_btn.click(preprocess_dataset, inputs=[upload_file, folder_path, lang, whisper_model, out_path, train_csv, eval_csv], outputs=[progress_data, train_csv, eval_csv])

        # ---- Tab 2: fine-tuning -----------------------------------------
        with gr.Tab("2 - Fine-tuning XTTS Encoder"):
            load_params_btn = gr.Button(value="Load Params from output folder")
            version = gr.Dropdown(
                label="XTTS base version",
                value="v2.0.2",
                choices=[
                    "v2.0.3",
                    "v2.0.2",
                    "v2.0.1",
                    "v2.0.0",
                    "main"
                ],
            )
            # NOTE(review): these two textboxes are filled by load_params,
            # but train_btn below reads train_csv_state/eval_csv_state from
            # Step 1 instead — loaded params never reach training. Verify.
            train_csv = gr.Textbox(
                label="Train CSV:",
            )
            eval_csv = gr.Textbox(
                label="Eval CSV:",
            )
            custom_model = gr.Textbox(
                label="(Optional) Custom model.pth file , leave blank if you want to use the base file.",
                value="",
            )
            num_epochs = gr.Slider(
                label="Number of epochs:",
                minimum=1,
                maximum=100,
                step=1,
                value=args.num_epochs,
            )
            batch_size = gr.Slider(
                label="Batch size:",
                minimum=2,
                maximum=512,
                step=1,
                value=args.batch_size,
            )
            # NOTE(review): CLI default for --grad_acumm is 1, which is below
            # this slider's minimum of 2 — confirm the intended range.
            grad_acumm = gr.Slider(
                label="Grad accumulation steps:",
                minimum=2,
                maximum=128,
                step=1,
                value=args.grad_acumm,
            )
            max_audio_length = gr.Slider(
                label="Max permitted audio size in seconds:",
                minimum=2,
                maximum=20,
                step=1,
                value=args.max_audio_length,
            )
            # Which intermediate artifacts to delete after optimizing;
            # default "run" removes the training-run folder.
            clear_train_data = gr.Dropdown(
                label="Clear train data, you will delete selected folder, after optimizing",
                value="run",
                choices=[
                    "none",
                    "run",
                    "dataset",
                    "all"
                ])
            progress_train = gr.Label(
                label="Progress:"
            )
            train_btn = gr.Button(value="Step 2 - Run the training")
            optimize_model_btn = gr.Button(value="Step 2.5 - Optimize the model")
            load_params_btn.click(load_params, inputs=[out_path], outputs=[progress_train, train_csv, eval_csv, lang])
            train_output_state = State()
            optimize_output_state = State()
            train_btn.click(train_model, inputs=[custom_model, version, lang, train_csv_state, eval_csv_state, num_epochs, batch_size, grad_acumm, out_path, max_audio_length], outputs=[progress_train, train_output_state])
            optimize_model_btn.click(optimize_model, inputs=[out_path, clear_train_data], outputs=[progress_train, optimize_output_state])
            #train_btn.click(train_model, inputs=[custom_model, version, lang, train_csv_state, eval_csv_state, num_epochs, batch_size, grad_acumm, out_path, max_audio_length], outputs=[progress_train, train_output_state])
            # train_btn.click(train_model, inputs=[custom_model, version, lang, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, out_path, max_audio_length], outputs=[progress_train, "temp", "temp", "temp", "temp", "temp"])
            #optimize_model_btn.click(optimize_model, inputs=[out_path, clear_train_data], outputs=[progress_train, "temp"])

        # ---- Tab 3: inference with the fine-tuned model ------------------
        with gr.Tab("3 - Inference"):
            with gr.Row():
                # Column 1: model file paths + load button.
                with gr.Column() as col1:
                    load_params_tts_btn = gr.Button(value="Load params for TTS from output folder")
                    xtts_checkpoint = gr.Textbox(
                        label="XTTS checkpoint path:",
                        value="",
                    )
                    xtts_config = gr.Textbox(
                        label="XTTS config path:",
                        value="",
                    )
                    xtts_vocab = gr.Textbox(
                        label="XTTS vocab path:",
                        value="",
                    )
                    xtts_speaker = gr.Textbox(
                        label="XTTS speaker path:",
                        value="",
                    )
                    progress_load = gr.Label(
                        label="Progress:"
                    )
                    load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")
                # Column 2: generation inputs and sampling settings.
                with gr.Column() as col2:
                    speaker_reference_audio = gr.Textbox(
                        label="Speaker reference audio:",
                        value="",
                    )
                    tts_language = gr.Dropdown(
                        label="Language",
                        value="en",
                        choices=[
                            "en",
                            "es",
                            "fr",
                            "de",
                            "it",
                            "pt",
                            "pl",
                            "tr",
                            "ru",
                            "nl",
                            "cs",
                            "ar",
                            "zh",
                            "hu",
                            "ko",
                            "ja",
                        ]
                    )
                    tts_text = gr.Textbox(
                        label="Input Text.",
                        value="This model sounds really good and above all, it's reasonably fast.",
                    )
                    with gr.Accordion("Advanced settings", open=False) as acr:
                        temperature = gr.Slider(
                            label="temperature",
                            minimum=0,
                            maximum=1,
                            step=0.05,
                            value=0.75,
                        )
                        length_penalty = gr.Slider(
                            label="length_penalty",
                            minimum=-10.0,
                            maximum=10.0,
                            step=0.5,
                            value=1,
                        )
                        repetition_penalty = gr.Slider(
                            label="repetition penalty",
                            minimum=1,
                            maximum=10,
                            step=0.5,
                            value=5,
                        )
                        top_k = gr.Slider(
                            label="top_k",
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=50,
                        )
                        top_p = gr.Slider(
                            label="top_p",
                            minimum=0,
                            maximum=1,
                            step=0.05,
                            value=0.85,
                        )
                        sentence_split = gr.Checkbox(
                            label="Enable text splitting",
                            value=True,
                        )
                        # When checked, the sliders above are ignored in
                        # favour of the model config — TODO confirm in run_tts.
                        use_config = gr.Checkbox(
                            label="Use Inference settings from config, if disabled use the settings above",
                            value=False,
                        )
                    tts_btn = gr.Button(value="Step 4 - Inference")
                # Column 3: generation results.
                with gr.Column() as col3:
                    progress_gen = gr.Label(
                        label="Progress:"
                    )
                    tts_output_audio = gr.Audio(label="Generated Audio.")
                    reference_audio = gr.Audio(label="Reference audio used.")
                # Column 4: batch generation from an SRT subtitle file.
                with gr.Column() as col4:
                    srt_upload = gr.File(label="Upload SRT File")
                    generate_srt_audio_btn = gr.Button(value="Generate Audio from SRT")
                    srt_output_audio = gr.Audio(label="Combined Audio from SRT")
                    error_message = gr.Textbox(label="Error Message", visible=False)
            generate_srt_audio_btn.click(process_srt_and_generate_audio, inputs=[srt_upload, tts_language, speaker_reference_audio, temperature, length_penalty, repetition_penalty, top_k, top_p, sentence_split, use_config], outputs=[srt_output_audio])
            load_btn.click(load_model, inputs=[xtts_checkpoint, xtts_config, xtts_vocab, xtts_speaker], outputs=[progress_load])
            tts_btn.click(run_tts, inputs=[tts_language, tts_text, speaker_reference_audio, temperature, length_penalty, repetition_penalty, top_k, top_p, sentence_split, use_config], outputs=[progress_gen, tts_output_audio, reference_audio])
            load_params_tts_btn.click(load_params_tts, inputs=[out_path, version], outputs=[progress_load, xtts_checkpoint, xtts_config, xtts_vocab, xtts_speaker, speaker_reference_audio])

        # ---- Tab 4: OpenVoice voice conversion ---------------------------
        with gr.Tab("4 - Voice conversion"):
            with gr.Column() as col0:
                gr.Markdown("## OpenVoice Conversion Tool")
                voice_convert_seed = gr.File(label="Upload Reference Speaker Audio being generated")
                audio_to_convert = gr.Textbox(
                    label="Input the to-be-convert audio location",
                    value="",
                )
                convert_button = gr.Button("Convert Voice")
                converted_audio = gr.Audio(label="Converted Audio")
                convert_button.click(convert_voice, inputs=[voice_convert_seed, audio_to_convert], outputs=[converted_audio])

        # ---- Tab 5: log viewer -------------------------------------------
        with gr.Tab("5 - Logs"):
            # Button to read the accumulated log file on demand.
            read_logs_btn = gr.Button("Read Logs")
            log_output = gr.Textbox(label="Log Output")
            read_logs_btn.click(fn=read_logs, inputs=None, outputs=log_output)

    # NOTE(review): share is hard-coded to True; the is_share/webui_port_main
    # values imported from config are unused here — confirm that is intended.
    #demo.launch(share=is_share, server_port=webui_port_main, server_name="0.0.0.0")
    demo.launch(
        #share=False,
        share=True,
        debug=False,
        server_port=args.port,
        #server_name="localhost"
        server_name="0.0.0.0"
    )
    # Inert string literal kept for reference: an alternative launch
    # configuration that was disabled by quoting it out.
    '''
    demo.launch(
        server_name="0.0.0.0",
        inbrowser=True,
        share=is_share,
        server_port=webui_port_main,
        quiet=True,
    )
    '''
if __name__ == "__main__":
    # Start each session with a fresh log file, then bring up the web UI.
    main_log = "logs/main.log"
    remove_log_file(main_log)
    launch()