import argparse
import os
import sys
import tempfile
import logging
from pathlib import Path
import shutil
import glob

import gradio as gr
import librosa.display
import numpy as np
from datetime import datetime
from pydub import AudioSegment
import pysrt
import torch
import torchaudio
import traceback

from utils.formatter import format_audio_list, find_latest_best_model
from utils.gpt_train import train_gpt

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

from openvoice_cli.downloader import download_checkpoint
from openvoice_cli.api import ToneColorConverter
import openvoice_cli.se_extractor as se_extractor

# The full logging configuration happens after stdout is redirected (see the
# Logger class below); here we only create the module logger.
logger = logging.getLogger(__name__)
# Clear logs
def remove_log_file(file_path):
    log_file = Path(file_path)
    if log_file.exists() and log_file.is_file():
        log_file.unlink()

# remove_log_file(str(Path.cwd() / "log.out"))

def clear_gpu_cache():
    # Clear the GPU cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
XTTS_MODEL = None

def load_model(xtts_checkpoint, xtts_config, xtts_vocab, xtts_speaker):
    global XTTS_MODEL
    clear_gpu_cache()
    if not xtts_checkpoint or not xtts_config or not xtts_vocab:
        return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields!"
    config = XttsConfig()
    config.load_json(xtts_config)
    XTTS_MODEL = Xtts.init_from_config(config)
    print("Loading XTTS model!")
    XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, speaker_file_path=xtts_speaker, use_deepspeed=False)
    if torch.cuda.is_available():
        XTTS_MODEL.cuda()

    print("Model Loaded!")
    return "Model Loaded!"
def run_tts(lang, tts_text, speaker_audio_file, temperature, length_penalty, repetition_penalty, top_k, top_p, speed, sentence_split, use_config, output_file_path=None):
    # `output_file_path` is optional and placed last: the Gradio button passes
    # only the eleven UI inputs, while the SRT pipeline passes an explicit
    # target path. (The original signature took it positionally in fourth
    # place and then never used it, so SRT output files were never written.)
    if XTTS_MODEL is None:
        raise Exception("XTTS_MODEL is not loaded. Please load the model before running TTS.")
    if not tts_text.strip():
        raise ValueError("Text for TTS is empty.")
    if not os.path.exists(speaker_audio_file):
        raise FileNotFoundError(f"Speaker audio file not found: {speaker_audio_file}")

    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
        audio_path=speaker_audio_file,
        gpt_cond_len=XTTS_MODEL.config.gpt_cond_len,
        max_ref_length=XTTS_MODEL.config.max_ref_len,
        sound_norm_refs=XTTS_MODEL.config.sound_norm_refs,
    )

    if use_config:
        out = XTTS_MODEL.inference(
            text=tts_text,
            language=lang,
            gpt_cond_latent=gpt_cond_latent,
            speaker_embedding=speaker_embedding,
            temperature=XTTS_MODEL.config.temperature,  # inference settings from the model config
            length_penalty=XTTS_MODEL.config.length_penalty,
            repetition_penalty=XTTS_MODEL.config.repetition_penalty,
            top_k=XTTS_MODEL.config.top_k,
            top_p=XTTS_MODEL.config.top_p,
            speed=speed,
            enable_text_splitting=True,
        )
    else:
        out = XTTS_MODEL.inference(
            text=tts_text,
            language=lang,
            gpt_cond_latent=gpt_cond_latent,
            speaker_embedding=speaker_embedding,
            temperature=temperature,  # custom settings from the UI sliders
            length_penalty=length_penalty,
            repetition_penalty=float(repetition_penalty),
            top_k=top_k,
            top_p=top_p,
            speed=speed,
            enable_text_splitting=sentence_split,
        )

    # Save to the requested path if one was given (SRT pipeline), otherwise to
    # a temporary file (interactive inference).
    if output_file_path is None:
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
            output_file_path = fp.name
    out["wav"] = torch.tensor(out["wav"]).unsqueeze(0)
    torchaudio.save(output_file_path, out["wav"], 24000)

    return "Speech generated!", output_file_path, speaker_audio_file
def load_params_tts(out_path, version):
    out_path = Path(out_path)

    # base_model_path = Path.cwd() / "models" / version
    # if not base_model_path.exists():
    #     return "Base model not found !","","",""

    ready_model_path = out_path / "ready"

    vocab_path = ready_model_path / "vocab.json"
    config_path = ready_model_path / "config.json"
    speaker_path = ready_model_path / "speakers_xtts.pth"
    reference_path = ready_model_path / "reference.wav"

    model_path = ready_model_path / "model.pth"
    if not model_path.exists():
        model_path = ready_model_path / "unoptimize_model.pth"
        if not model_path.exists():
            # Return one value per wired output component (six outputs).
            return "Params for TTS not found", "", "", "", "", ""

    return "Params for TTS loaded", model_path, config_path, vocab_path, speaker_path, reference_path
def process_srt_and_generate_audio(
        srt_file,
        lang,
        speaker_reference_audio,
        temperature,
        length_penalty,
        repetition_penalty,
        top_k,
        top_p,
        speed,
        sentence_split,
        use_config):
    try:
        subtitles = pysrt.open(srt_file)
        audio_files = []
        output_dir = create_output_dir(parent_dir='/content/drive/MyDrive/Voice Conversion Result')

        for index, subtitle in enumerate(subtitles):
            audio_filename = f"audio_{index+1:03d}.wav"
            audio_file_path = os.path.join(output_dir, audio_filename)
            subtitle_text = remove_endperiod(subtitle.text)
            run_tts(lang, subtitle_text, speaker_reference_audio,
                    temperature, length_penalty, repetition_penalty, top_k, top_p,
                    speed, sentence_split, use_config,
                    output_file_path=audio_file_path)
            logger.info(f"Generated audio file: {audio_file_path}")
            audio_files.append(audio_file_path)

        output_audio_path = merge_audio_with_srt_timing(subtitles, audio_files, output_dir)
        return output_audio_path
    except Exception as e:
        logger.error(f"Error in process_srt_and_generate_audio: {e}")
        raise
def create_output_dir(parent_dir):
    try:
        # Folder name based on the current date and time
        folder_name = datetime.now().strftime("audio_outputs_%Y-%m-%d_%H-%M-%S")
        # Parent directory; assumed to be on Google Drive when running in Colab
        # parent_dir = "/content/drive/MyDrive/Voice Conversion Result"
        # Full folder path
        output_dir = os.path.join(parent_dir, folder_name)
        # Create the folder
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        logger.info(f"Created output directory at: {output_dir}")
        return output_dir
    except Exception as e:
        logger.error(f"Failed to create output directory: {e}")
        raise
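# For example, a run started at 2024-01-31 12:00:00 (hypothetical) would create
# "<parent_dir>/audio_outputs_2024-01-31_12-00-00".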
def srt_time_to_ms(srt_time):
    return (srt_time.hours * 3600 + srt_time.minutes * 60 + srt_time.seconds) * 1000 + srt_time.milliseconds
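# Worked example: the SRT timestamp 00:01:02,500 maps to
# (0 * 3600 + 1 * 60 + 2) * 1000 + 500 = 62500 ms.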
def merge_audio_with_srt_timing(subtitles, audio_files, output_dir):
    try:
        combined = AudioSegment.silent(duration=0)
        last_position_ms = 0

        for subtitle, audio_file in zip(subtitles, audio_files):
            start_time_ms = srt_time_to_ms(subtitle.start)
            if last_position_ms < start_time_ms:
                silence_duration = start_time_ms - last_position_ms
                combined += AudioSegment.silent(duration=silence_duration)
                last_position_ms = start_time_ms

            audio = AudioSegment.from_file(audio_file, format="wav")
            combined += audio
            last_position_ms += len(audio)

        output_path = os.path.join(output_dir, "combined_audio_with_timing.wav")
        # combined_with_set_frame_rate = combined.set_frame_rate(24000)
        # combined_with_set_frame_rate.export(output_path, format="wav")
        combined.export(output_path, format="wav")
        logger.info(f"Exported combined audio to: {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"Error merging audio files: {e}")
        raise
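# Timing sketch: if clip 1 starts at 0 ms and lasts 1500 ms, and subtitle 2
# starts at 2000 ms, 500 ms of silence is inserted between them. If a clip
# overruns the next subtitle's start, no silence is added and the next clip
# simply follows it, so timings drift late rather than overlap.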
def remove_endperiod(subtitle):
    """Removes the period (.) at the end of a subtitle text."""
    if subtitle.endswith('.'):
        subtitle = subtitle[:-1]
    return subtitle
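# Example: remove_endperiod("Hello there.") -> "Hello there"; text without a
# trailing period is returned unchanged.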
def convert_voice(reference_audio, audio_to_convert):
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Build the output path next to the input audio
    # input_audio_path = audio_to_convert
    base_name, ext = os.path.splitext(os.path.basename(audio_to_convert))
    new_file_name = base_name + 'convertedvoice' + ext
    output_path = os.path.join(os.path.dirname(audio_to_convert), new_file_name)

    tune_one(input_file=audio_to_convert, ref_file=reference_audio, output_file=output_path, device=device)

    return output_path
def tune_one(input_file, ref_file, output_file, device):
    current_dir = os.path.dirname(os.path.realpath(__file__))
    checkpoints_dir = os.path.join(current_dir, 'checkpoints')
    ckpt_converter = os.path.join(checkpoints_dir, 'converter')

    if not os.path.exists(ckpt_converter):
        os.makedirs(ckpt_converter, exist_ok=True)
        download_checkpoint(ckpt_converter)

    tone_color_converter = ToneColorConverter(os.path.join(ckpt_converter, 'config.json'), device=device)
    tone_color_converter.load_ckpt(os.path.join(ckpt_converter, 'checkpoint.pth'))

    source_se, _ = se_extractor.get_se(input_file, tone_color_converter, vad=True)
    target_se, _ = se_extractor.get_se(ref_file, tone_color_converter, vad=True)

    # Ensure output directory exists and is writable
    output_dir = os.path.dirname(output_file)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    # Run the tone color converter
    tone_color_converter.convert(
        audio_src_path=input_file,
        src_se=source_se,
        tgt_se=target_se,
        output_path=output_file,
    )
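# A minimal usage sketch (paths are hypothetical): convert "speech.wav" so it
# takes on the tone color of "target_voice.wav":
#   tune_one(input_file="speech.wav", ref_file="target_voice.wav",
#            output_file="speech_converted.wav",
#            device="cuda:0" if torch.cuda.is_available() else "cpu")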
'''
def tune_batch(input_dir, ref_file, output_dir=None, device='cpu', output_format='.wav'):
    current_dir = os.path.dirname(os.path.realpath(__file__))
    checkpoints_dir = os.path.join(current_dir, 'checkpoints')
    ckpt_converter = os.path.join(checkpoints_dir, 'converter')

    if not os.path.exists(ckpt_converter):
        os.makedirs(ckpt_converter, exist_ok=True)
        download_checkpoint(ckpt_converter)

    tone_color_converter = ToneColorConverter(os.path.join(ckpt_converter, 'config.json'), device=device)
    tone_color_converter.load_ckpt(os.path.join(ckpt_converter, 'checkpoint.pth'))

    target_se, _ = se_extractor.get_se(ref_file, tone_color_converter, vad=True)

    # Use default output directory 'out' if not provided
    if output_dir is None:
        output_dir = os.path.join(current_dir, 'out')
    os.makedirs(output_dir, exist_ok=True)

    # Check for any audio files in the input directory (wav, mp3, flac) using glob
    audio_extensions = ('*.wav', '*.mp3', '*.flac')
    audio_files = []
    for extension in audio_extensions:
        audio_files.extend(glob.glob(os.path.join(input_dir, extension)))

    for audio_file in tqdm(audio_files, "Tune file", len(audio_files)):
        # Extract source SE from audio file
        source_se, _ = se_extractor.get_se(audio_file, tone_color_converter, vad=True)

        # Run the tone color converter
        filename_without_extension = os.path.splitext(os.path.basename(audio_file))[0]
        output_filename = f"{filename_without_extension}_tuned{output_format}"
        output_file = os.path.join(output_dir, output_filename)
        tone_color_converter.convert(
            audio_src_path=audio_file,
            src_se=source_se,
            tgt_se=target_se,
            output_path=output_file,
        )
        print(f"Converted {audio_file} to {output_file}")

    return output_dir

def main_single(args):
    tune_one(input_file=args.input, ref_file=args.ref, output_file=args.output, device=args.device)

def main_batch(args):
    output_dir = tune_batch(
        input_dir=args.input_dir,
        ref_file=args.ref_file,
        output_dir=args.output_dir,
        device=args.device,
        output_format=args.output_format
    )
    print(f"Batch processing complete. Converted files are saved in {output_dir}")
'''
# Define a logger that tees writes to both the terminal and a log file
class Logger:
    def __init__(self, filename="log.out"):
        self.log_file = filename
        self.terminal = sys.stdout
        self.log = open(self.log_file, "w")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        return False

# Redirect stdout and stderr to the tee
sys.stdout = Logger()
sys.stderr = sys.stdout
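# Note: anything printed from here on (including tracebacks via sys.stderr)
# is written both to the console and to "log.out", which read_logs() below
# reads back for display.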
# Configure logging now that stdout is redirected, so log records reach both
# the console and log.out through the tee.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)

def read_logs():
    sys.stdout.flush()
    with open(sys.stdout.log_file, "r") as f:
        return f.read()
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""XTTS fine-tuning demo\n\n"""
        """
        Example runs:
        python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port 5003
        """,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--port",
        type=int,
        help="Port to run the gradio demo. Default: 5003",
        default=5003,
    )
    parser.add_argument(
        "--out_path",
        type=str,
        help="Output path (where data and checkpoints will be saved). Default: finetune_models/",
        default=str(Path.cwd() / "finetune_models"),
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        help="Number of epochs to train. Default: 6",
        default=6,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help="Batch size. Default: 2",
        default=2,
    )
    parser.add_argument(
        "--grad_acumm",
        type=int,
        help="Grad accumulation steps. Default: 1",
        default=1,
    )
    parser.add_argument(
        "--max_audio_length",
        type=int,
        help="Max permitted audio size in seconds. Default: 11",
        default=11,
    )

    args = parser.parse_args()
    with gr.Blocks() as demo:
        with gr.Tab("0 - Voice conversion"):
            with gr.Column() as col0:
                gr.Markdown("## OpenVoice Conversion Tool")
                voice_convert_seed = gr.File(label="Upload the reference speaker audio (the voice to convert to)")
                # pitch_shift_slider = gr.Slider(minimum=-12, maximum=12, step=1, value=0, label="Pitch Shift (Semitones)")
                audio_to_convert = gr.Textbox(
                    label="Path of the audio file to convert",
                    value="",
                )
                convert_button = gr.Button("Convert Voice")
                converted_audio = gr.Audio(label="Converted Audio")

                convert_button.click(
                    convert_voice,
                    inputs=[voice_convert_seed, audio_to_convert],  # , pitch_shift_slider],
                    outputs=[converted_audio]
                )
with gr.Tab("1 - Data processing"): | |
out_path = gr.Textbox( | |
label="Output path (where data and checkpoints will be saved):", | |
value=args.out_path, | |
) | |
# upload_file = gr.Audio( | |
# sources="upload", | |
# label="Select here the audio files that you want to use for XTTS trainining !", | |
# type="filepath", | |
# ) | |
upload_file = gr.File( | |
file_count="multiple", | |
label="Select here the audio files that you want to use for XTTS trainining (Supported formats: wav, mp3, and flac)", | |
) | |
whisper_model = gr.Dropdown( | |
label="Whisper Model", | |
value="large-v3", | |
choices=[ | |
"large-v3", | |
"large-v2", | |
"large", | |
"medium", | |
"small" | |
], | |
) | |
lang = gr.Dropdown( | |
label="Dataset Language", | |
value="en", | |
choices=[ | |
"en", | |
"es", | |
"fr", | |
"de", | |
"it", | |
"pt", | |
"pl", | |
"tr", | |
"ru", | |
"nl", | |
"cs", | |
"ar", | |
"zh", | |
"hu", | |
"ko", | |
"ja" | |
], | |
) | |
progress_data = gr.Label( | |
label="Progress:" | |
) | |
# demo.load(read_logs, None, logs, every=1) | |
prompt_compute_btn = gr.Button(value="Step 1 - Create dataset") | |
            def preprocess_dataset(audio_path, language, whisper_model, out_path, train_csv, eval_csv, progress=gr.Progress(track_tqdm=True)):
                clear_gpu_cache()

                train_csv = ""
                eval_csv = ""

                out_path = os.path.join(out_path, "dataset")
                os.makedirs(out_path, exist_ok=True)

                if audio_path is None:
                    return "You should provide one or more audio files! If you did, the upload may not have finished yet!", "", ""
                else:
                    try:
                        train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, whisper_model=whisper_model, target_language=language, out_path=out_path, gradio_progress=progress)
                    except Exception:
                        traceback.print_exc()
                        error = traceback.format_exc()
                        return f"The data processing was interrupted due to an error! Please check the console for the full error message!\nError summary: {error}", "", ""

                # clear_gpu_cache()

                # If the total audio length is under 2 minutes, raise an error
                if audio_total_size < 120:
                    message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
                    print(message)
                    return message, "", ""

                print("Dataset Processed!")
                return "Dataset Processed!", train_meta, eval_meta
with gr.Tab("2 - Fine-tuning XTTS Encoder"): | |
load_params_btn = gr.Button(value="Load Params from output folder") | |
version = gr.Dropdown( | |
label="XTTS base version", | |
value="v2.0.2", | |
choices=[ | |
"v2.0.3", | |
"v2.0.2", | |
"v2.0.1", | |
"v2.0.0", | |
"main" | |
], | |
) | |
train_csv = gr.Textbox( | |
label="Train CSV:", | |
) | |
eval_csv = gr.Textbox( | |
label="Eval CSV:", | |
) | |
custom_model = gr.Textbox( | |
label="(Optional) Custom model.pth file , leave blank if you want to use the base file.", | |
value="", | |
) | |
num_epochs = gr.Slider( | |
label="Number of epochs:", | |
minimum=1, | |
maximum=100, | |
step=1, | |
value=args.num_epochs, | |
) | |
batch_size = gr.Slider( | |
label="Batch size:", | |
minimum=2, | |
maximum=512, | |
step=1, | |
value=args.batch_size, | |
) | |
grad_acumm = gr.Slider( | |
label="Grad accumulation steps:", | |
minimum=2, | |
maximum=128, | |
step=1, | |
value=args.grad_acumm, | |
) | |
max_audio_length = gr.Slider( | |
label="Max permitted audio size in seconds:", | |
minimum=2, | |
maximum=20, | |
step=1, | |
value=args.max_audio_length, | |
) | |
clear_train_data = gr.Dropdown( | |
label="Clear train data, you will delete selected folder, after optimizing", | |
value="run", | |
choices=[ | |
"none", | |
"run", | |
"dataset", | |
"all" | |
]) | |
progress_train = gr.Label( | |
label="Progress:" | |
) | |
# demo.load(read_logs, None, logs_tts_train, every=1) | |
train_btn = gr.Button(value="Step 2 - Run the training") | |
optimize_model_btn = gr.Button(value="Step 2.5 - Optimize the model") | |
            def train_model(custom_model, version, language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
                clear_gpu_cache()

                run_dir = Path(output_path) / "run"

                # Remove the train dir (a directory, so shutil.rmtree rather than os.remove)
                if run_dir.exists():
                    shutil.rmtree(run_dir)

                # Check if the dataset language matches the language you specified
                lang_file_path = Path(output_path) / "dataset" / "lang.txt"

                # Check if lang.txt already exists and contains a different language
                current_language = None
                if lang_file_path.exists():
                    with open(lang_file_path, 'r', encoding='utf-8') as existing_lang_file:
                        current_language = existing_lang_file.read().strip()
                        if current_language != language:
                            print("The language prepared for the dataset does not match the specified language. Switching to the dataset's language.")
                            language = current_language

                if not train_csv or not eval_csv:
                    return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields!", "", "", "", "", ""
                try:
                    # Convert seconds to waveform frames
                    max_audio_length = int(max_audio_length * 22050)
                    speaker_xtts_path, config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(custom_model, version, language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
                except Exception:
                    traceback.print_exc()
                    error = traceback.format_exc()
                    return f"The training was interrupted due to an error! Please check the console for the full error message!\nError summary: {error}", "", "", "", "", ""

                # Copy original files to avoid parameter-change issues
                # os.system(f"cp {config_path} {exp_path}")
                # os.system(f"cp {vocab_file} {exp_path}")

                ready_dir = Path(output_path) / "ready"

                ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")

                shutil.copy(ft_xtts_checkpoint, ready_dir / "unoptimize_model.pth")
                # os.remove(ft_xtts_checkpoint)

                ft_xtts_checkpoint = os.path.join(ready_dir, "unoptimize_model.pth")

                # Copy the reference audio to the output folder and rename it
                speaker_reference_path = Path(speaker_wav)
                speaker_reference_new_path = ready_dir / "reference.wav"
                shutil.copy(speaker_reference_path, speaker_reference_new_path)

                print("Model training done!")
                # clear_gpu_cache()
                return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_xtts_path, speaker_reference_new_path
            def optimize_model(out_path, clear_train_data):
                # print(out_path)
                out_path = Path(out_path)  # Ensure that out_path is a Path object.

                ready_dir = out_path / "ready"
                run_dir = out_path / "run"
                dataset_dir = out_path / "dataset"

                # Clear specified training data directories.
                if clear_train_data in {"run", "all"} and run_dir.exists():
                    try:
                        shutil.rmtree(run_dir)
                    except PermissionError as e:
                        print(f"An error occurred while deleting {run_dir}: {e}")

                if clear_train_data in {"dataset", "all"} and dataset_dir.exists():
                    try:
                        shutil.rmtree(dataset_dir)
                    except PermissionError as e:
                        print(f"An error occurred while deleting {dataset_dir}: {e}")

                # Get the full path to the model
                model_path = ready_dir / "unoptimize_model.pth"

                if not model_path.is_file():
                    return "Unoptimized model not found in ready folder", ""

                # Load the checkpoint and remove unnecessary parts.
                checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
                del checkpoint["optimizer"]

                for key in list(checkpoint["model"].keys()):
                    if "dvae" in key:
                        del checkpoint["model"][key]

                os.remove(model_path)

                # Save the optimized model.
                optimized_model_file_name = "model.pth"
                optimized_model = ready_dir / optimized_model_file_name

                torch.save(checkpoint, optimized_model)
                ft_xtts_checkpoint = str(optimized_model)

                clear_gpu_cache()
                return f"Model optimized and saved at {ft_xtts_checkpoint}!", ft_xtts_checkpoint
            def load_params(out_path):
                path_output = Path(out_path)

                dataset_path = path_output / "dataset"

                if not dataset_path.exists():
                    return "The output folder does not exist!", "", "", ""

                eval_train = dataset_path / "metadata_train.csv"
                eval_csv = dataset_path / "metadata_eval.csv"

                # Read the dataset language back from lang.txt in the output directory
                lang_file_path = dataset_path / "lang.txt"

                current_language = None
                if os.path.exists(lang_file_path):
                    with open(lang_file_path, 'r', encoding='utf-8') as existing_lang_file:
                        current_language = existing_lang_file.read().strip()

                clear_gpu_cache()

                print(current_language)
                return "The data has been updated", eval_train, eval_csv, current_language
with gr.Tab("3 - Inference"): | |
with gr.Row(): | |
with gr.Column() as col1: | |
load_params_tts_btn = gr.Button(value="Load params for TTS from output folder") | |
xtts_checkpoint = gr.Textbox( | |
label="XTTS checkpoint path:", | |
value="", | |
) | |
xtts_config = gr.Textbox( | |
label="XTTS config path:", | |
value="", | |
) | |
xtts_vocab = gr.Textbox( | |
label="XTTS vocab path:", | |
value="", | |
) | |
xtts_speaker = gr.Textbox( | |
label="XTTS speaker path:", | |
value="", | |
) | |
progress_load = gr.Label( | |
label="Progress:" | |
) | |
load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model") | |
with gr.Column() as col2: | |
speaker_reference_audio = gr.Textbox( | |
label="Speaker reference audio:", | |
value="", | |
) | |
tts_language = gr.Dropdown( | |
label="Language", | |
value="en", | |
choices=[ | |
"en", | |
"es", | |
"fr", | |
"de", | |
"it", | |
"pt", | |
"pl", | |
"tr", | |
"ru", | |
"nl", | |
"cs", | |
"ar", | |
"zh", | |
"hu", | |
"ko", | |
"ja", | |
] | |
) | |
tts_text = gr.Textbox( | |
label="Input Text.", | |
value="This model sounds really good and above all, it's reasonably fast.", | |
) | |
with gr.Accordion("Advanced settings", open=False) as acr: | |
temperature = gr.Slider( | |
label="temperature", | |
minimum=0, | |
maximum=1, | |
step=0.05, | |
value=0.75, | |
) | |
length_penalty = gr.Slider( | |
label="length_penalty", | |
minimum=-10.0, | |
maximum=10.0, | |
step=0.5, | |
value=1, | |
) | |
repetition_penalty = gr.Slider( | |
label="repetition penalty", | |
minimum=1, | |
maximum=10, | |
step=0.5, | |
value=5, | |
) | |
top_k = gr.Slider( | |
label="top_k", | |
minimum=1, | |
maximum=100, | |
step=1, | |
value=50, | |
) | |
top_p = gr.Slider( | |
label="top_p", | |
minimum=0, | |
maximum=1, | |
step=0.05, | |
value=0.85, | |
) | |
speed = gr.Slider( | |
label="speed", | |
minimum=0.2, | |
maximum=4.0, | |
step=0.05, | |
value=1.0, | |
) | |
sentence_split = gr.Checkbox( | |
label="Enable text splitting", | |
value=True, | |
) | |
use_config = gr.Checkbox( | |
label="Use Inference settings from config, if disabled use the settings above", | |
value=False, | |
) | |
tts_btn = gr.Button(value="Step 4 - Inference") | |
with gr.Column() as col3: | |
progress_gen = gr.Label( | |
label="Progress:" | |
) | |
tts_output_audio = gr.Audio(label="Generated Audio.") | |
reference_audio = gr.Audio(label="Reference audio used.") | |
with gr.Column() as col4: | |
srt_upload = gr.File(label="Upload SRT File") | |
generate_srt_audio_btn = gr.Button(value="Generate Audio from SRT") | |
srt_output_audio = gr.Audio(label="Combined Audio from SRT") | |
error_message = gr.Textbox(label="Error Message", visible=False) # 错误消息组件,默认不显示 | |
generate_srt_audio_btn.click( | |
fn=process_srt_and_generate_audio, | |
inputs=[ | |
srt_upload, | |
tts_language, | |
speaker_reference_audio, | |
temperature, | |
length_penalty, | |
repetition_penalty, | |
top_k, | |
top_p, | |
speed, | |
sentence_split, | |
use_config | |
], | |
outputs=[srt_output_audio] | |
) | |
        prompt_compute_btn.click(
            fn=preprocess_dataset,
            inputs=[
                upload_file,
                lang,
                whisper_model,
                out_path,
                train_csv,
                eval_csv
            ],
            outputs=[
                progress_data,
                train_csv,
                eval_csv,
            ],
        )

        load_params_btn.click(
            fn=load_params,
            inputs=[out_path],
            outputs=[
                progress_train,
                train_csv,
                eval_csv,
                lang
            ]
        )

        train_btn.click(
            fn=train_model,
            inputs=[
                custom_model,
                version,
                lang,
                train_csv,
                eval_csv,
                num_epochs,
                batch_size,
                grad_acumm,
                out_path,
                max_audio_length,
            ],
            outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, xtts_speaker, speaker_reference_audio],
        )

        optimize_model_btn.click(
            fn=optimize_model,
            inputs=[
                out_path,
                clear_train_data
            ],
            outputs=[progress_train, xtts_checkpoint],
        )

        load_btn.click(
            fn=load_model,
            inputs=[
                xtts_checkpoint,
                xtts_config,
                xtts_vocab,
                xtts_speaker
            ],
            outputs=[progress_load],
        )

        tts_btn.click(
            fn=run_tts,
            inputs=[
                tts_language,
                tts_text,
                speaker_reference_audio,
                temperature,
                length_penalty,
                repetition_penalty,
                top_k,
                top_p,
                speed,
                sentence_split,
                use_config
            ],
            outputs=[progress_gen, tts_output_audio, reference_audio],
        )

        load_params_tts_btn.click(
            fn=load_params_tts,
            inputs=[
                out_path,
                version
            ],
            outputs=[progress_load, xtts_checkpoint, xtts_config, xtts_vocab, xtts_speaker, speaker_reference_audio],
        )
    demo.launch(
        # share=False,
        share=True,
        debug=False,
        server_port=args.port,
        # server_name="localhost"
        server_name="0.0.0.0"
    )