Spaces:
Build error
Build error
import os | |
import shutil | |
import torch | |
import traceback | |
from pathlib import Path | |
from multiprocessing import cpu_count | |
from functions.core_functions1 import clear_gpu_cache | |
from functions.logging_utils import remove_log_file | |
from functions.slice_utils import open_slice, close_slice | |
from utils.formatter import format_audio_list | |
from utils.gpt_train import train_gpt | |
def get_audio_files_from_folder(folder_path, extensions=(".wav", ".mp3", ".flac", ".m4a", ".webm")):
    """Recursively collect the audio files found under *folder_path*.

    Args:
        folder_path: Directory to search; it is walked recursively with
            ``os.walk``. A missing directory simply yields no files.
        extensions: File suffix (or tuple of suffixes) treated as audio.
            Matching ignores case, so ``clip.WAV`` is picked up as well as
            ``clip.wav``.

    Returns:
        A list of ``os.path.join(root, name)`` paths in ``os.walk`` order.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    audio_files = []
    for root, _dirs, names in os.walk(folder_path):
        for name in names:
            # str.endswith accepts a tuple, so one call tests every suffix.
            if name.lower().endswith(suffixes):
                audio_files.append(os.path.join(root, name))
    return audio_files
def preprocess_dataset(audio_path, audio_folder, language, whisper_model, out_path, train_csv, eval_csv, progress):
    """Build the train/eval metadata for fine-tuning from the user's audio.

    Audio comes from ``audio_path`` when it is non-empty. Otherwise it comes
    from every audio file found recursively under ``audio_folder``. The files
    are handed to ``format_audio_list``, which writes into
    ``<out_path>/dataset`` (the folder is created here if needed).

    Args:
        audio_path: Collection of audio file paths; ignored when ``None`` or empty.
        audio_folder: Folder searched for audio when ``audio_path`` is not
            given; ignored when ``None`` or empty.
        language: Forwarded to ``format_audio_list`` as ``target_language``.
        whisper_model: Forwarded to ``format_audio_list``.
        out_path: Base output directory; ``dataset`` is created beneath it.
        train_csv, eval_csv: Unused here; kept for interface compatibility.
        progress: Forwarded to ``format_audio_list`` as ``gradio_progress``.

    Returns:
        ``(message, train_meta, eval_meta)``. On any failure, or when less
        than 120 (seconds, presumably) of audio was provided, the last two
        elements are ``""``.
    """
    out_path = os.path.join(out_path, "dataset")
    os.makedirs(out_path, exist_ok=True)

    # Pick the audio source: explicit files win over a folder.
    if audio_path:
        audio_files = audio_path
    elif audio_folder:
        audio_files = get_audio_files_from_folder(audio_folder)
        if not audio_files:
            return f"No audio files were found in the folder: {audio_folder}", "", ""
    else:
        return "You should provide either audio files or a folder containing audio files!", "", ""

    try:
        train_meta, eval_meta, audio_total_size = format_audio_list(audio_files, whisper_model=whisper_model, target_language=language, out_path=out_path, gradio_progress=progress)
    except Exception:
        # Report the failure to the UI instead of crashing; Ctrl-C still propagates.
        traceback.print_exc()
        error = traceback.format_exc()
        return f"The data processing was interrupted due to an error! Please check the console to verify the full error message! \n Error summary: {error}", "", ""

    if audio_total_size < 120:
        message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
        print(message)
        return message, "", ""

    print("Dataset Processed!")
    return "Dataset Processed!", train_meta, eval_meta
def train_model(custom_model, version, language, train_csv, eval_csv, num_epochs, batch_size, grad_accum, output_path, max_audio_length):
    """Run ``train_gpt`` and stage its output in ``<output_path>/ready``.

    The request is validated before anything is deleted, so an incomplete
    request no longer wipes the previous ``<output_path>/run`` folder. On
    success, ``best_model.pth`` from the experiment folder is copied to
    ``ready/unoptimize_model.pth`` and the speaker wav to ``ready/reference.wav``.

    Args:
        custom_model, version, language, num_epochs, batch_size, grad_accum:
            Forwarded unchanged to ``train_gpt``.
        train_csv, eval_csv: Metadata CSV paths; both must be non-empty.
        output_path: Base output directory holding ``run`` and ``ready``.
        max_audio_length: Multiplied by 22050 and truncated to ``int`` before
            being passed to ``train_gpt`` (presumably seconds -> samples at
            22.05 kHz; confirm against ``train_gpt``).

    Returns:
        ``(message, config_path, vocab_file, checkpoint_path,
        speaker_xtts_path, speaker_reference_path)``. Every element after
        the message is ``""`` on failure.
    """
    if not train_csv or not eval_csv:
        return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields!", "", "", "", "", ""

    # Start from a clean run folder. It is normally a directory, which
    # os.remove() cannot delete, so directories need rmtree.
    run_dir = Path(output_path) / "run"
    if run_dir.exists():
        try:
            if run_dir.is_dir():
                shutil.rmtree(run_dir)
            else:
                os.remove(run_dir)
        except OSError as e:
            traceback.print_exc()
            return f"Could not remove the previous run folder {run_dir}: {e}", "", "", "", "", ""

    try:
        max_audio_length = int(max_audio_length * 22050)
        speaker_xtts_path, config_path, _original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(custom_model, version, language, num_epochs, batch_size, grad_accum, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
    except Exception:
        # Report the failure to the UI instead of crashing; Ctrl-C still propagates.
        traceback.print_exc()
        error = traceback.format_exc()
        return f"The training was interrupted due to an error! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", "", ""

    ready_dir = Path(output_path) / "ready"
    # shutil.copy requires the destination folder to exist already.
    ready_dir.mkdir(parents=True, exist_ok=True)

    ft_xtts_checkpoint = os.path.join(ready_dir, "unoptimize_model.pth")
    shutil.copy(os.path.join(exp_path, "best_model.pth"), ft_xtts_checkpoint)

    speaker_reference_new_path = ready_dir / "reference.wav"
    shutil.copy(Path(speaker_wav), speaker_reference_new_path)

    print("Model training done!")
    return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_xtts_path, speaker_reference_new_path
def optimize_model(out_path, clear_train_data):
    """Strip training-only state from the fine-tuned checkpoint and optionally clean up.

    The checkpoint is read from ``<out_path>/ready/unoptimize_model.pth``. Its
    ``"optimizer"`` entry and every ``"model"`` entry whose key contains
    ``"dvae"`` are removed, the result is written to ``ready/model.pth``, and
    the unoptimized file is deleted.

    Args:
        out_path: Base output directory containing ``ready``, ``run`` and ``dataset``.
        clear_train_data: ``"run"`` deletes ``run``, ``"dataset"`` deletes
            ``dataset``, ``"all"`` deletes both; any other value keeps them.
            Deletion is best effort: failures are printed, not raised.

    Returns:
        ``(message, checkpoint_path)``. ``checkpoint_path`` is ``""`` when the
        unoptimized checkpoint does not exist.
    """
    out_path = Path(out_path)
    ready_dir = out_path / "ready"

    folders_to_clear = []
    if clear_train_data in {"run", "all"}:
        folders_to_clear.append(out_path / "run")
    if clear_train_data in {"dataset", "all"}:
        folders_to_clear.append(out_path / "dataset")
    for folder in folders_to_clear:
        if folder.exists():
            try:
                shutil.rmtree(folder)
            except OSError as e:  # PermissionError is a subclass of OSError
                print(f"An error occurred while deleting {folder}: {e}")

    model_path = ready_dir / "unoptimize_model.pth"
    if not model_path.is_file():
        return "Unoptimized model not found in ready folder", ""

    # NOTE(review): torch.load unpickles the file; only load checkpoints this
    # app produced. On torch >= 2.6 weights_only defaults to True, which may
    # reject trainer checkpoints holding non-tensor objects — confirm.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
    # Optimizer state is presumably only needed to resume training; tolerate its absence.
    checkpoint.pop("optimizer", None)
    model_state = checkpoint["model"]
    # Delete in place so the state dict keeps its original mapping type.
    for key in [k for k in model_state if "dvae" in k]:
        del model_state[key]

    optimized_model = ready_dir / "model.pth"
    # Save before deleting the source so a failed save cannot lose the model.
    torch.save(checkpoint, optimized_model)
    os.remove(model_path)

    ft_xtts_checkpoint = str(optimized_model)
    return f"Model optimized and saved at {ft_xtts_checkpoint}!", ft_xtts_checkpoint
def load_params(out_path):
    """Return the dataset CSV paths and saved language for *out_path*.

    Looks inside ``<out_path>/dataset``. The CSV paths are returned as
    ``pathlib.Path`` objects without checking that the files exist. The
    language is the stripped text of ``lang.txt``, or ``None`` when that
    file is absent; it is also printed to the console.

    Returns:
        ``(message, train_csv_path, eval_csv_path, language)``. When the
        dataset folder is missing: an error message and three ``""``.
    """
    dataset_dir = Path(out_path) / "dataset"
    if not dataset_dir.exists():
        return "The output folder does not exist!", "", "", ""

    lang_file = dataset_dir / "lang.txt"
    language = lang_file.read_text(encoding="utf-8").strip() if lang_file.exists() else None
    print(language)

    train_csv_path = dataset_dir / "metadata_train.csv"
    eval_csv_path = dataset_dir / "metadata_eval.csv"
    return "The data has been updated", train_csv_path, eval_csv_path, language
def load_params_tts(out_path, version):
    """Resolve the inference files stored in ``<out_path>/ready``.

    Args:
        out_path: Base output directory.
        version: Unused here; kept for interface compatibility.

    Returns:
        ``(message, checkpoint, config, vocab, speaker, reference)`` as strings.
        If any file is missing, the message lists every missing path (in
        config, vocab, checkpoint, speaker, reference order) and the other
        five elements are ``""``.
    """
    ready_dir = Path(out_path) / "ready"
    config_file = ready_dir / "config.json"
    vocab_file = ready_dir / "vocab.json"
    checkpoint_file = ready_dir / "model.pth"
    speaker_file = ready_dir / "speaker.pth"
    reference_file = ready_dir / "reference.wav"

    # This order determines how missing files are listed in the message.
    required = (config_file, vocab_file, checkpoint_file, speaker_file, reference_file)
    missing = [str(path) for path in required if not path.exists()]
    if missing:
        return f"The following files are missing from the ready folder: {', '.join(missing)}", "", "", "", "", ""

    print("Loaded parameters for TTS.")
    return "Loaded parameters for TTS.", str(checkpoint_file), str(config_file), str(vocab_file), str(speaker_file), str(reference_file)