"""Helpers for the XTTS fine-tuning workflow.

Covers the steps driven by the UI: dataset preprocessing, GPT training,
checkpoint optimization, and reloading previously produced parameters.
Each public function returns a tuple whose first element is a status
message; on failure the remaining elements are empty strings.
"""

import os
import shutil
import traceback
from multiprocessing import cpu_count
from pathlib import Path

import torch

from functions.core_functions1 import clear_gpu_cache
from functions.logging_utils import remove_log_file
from functions.slice_utils import open_slice, close_slice
from utils.formatter import format_audio_list
from utils.gpt_train import train_gpt

# File suffixes (case-sensitive) picked up when scanning a folder for audio.
_AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac", ".m4a", ".webm")


def get_audio_files_from_folder(folder_path):
    """Return the paths of all audio files found under *folder_path*.

    The folder is walked recursively with ``os.walk``; a file is kept when
    its name ends with one of ``_AUDIO_EXTENSIONS``.

    Args:
        folder_path: Root directory to scan.

    Returns:
        A list of file paths (``os.path.join(root, name)``) in walk order.
    """
    return [
        os.path.join(root, name)
        for root, _dirs, files in os.walk(folder_path)
        for name in files
        if name.endswith(_AUDIO_EXTENSIONS)
    ]


def preprocess_dataset(audio_path, audio_folder, language, whisper_model,
                       out_path, train_csv, eval_csv, progress):
    """Build the training dataset from uploaded files or a folder.

    Explicitly provided ``audio_path`` files take priority; otherwise the
    audio files found under ``audio_folder`` are used.

    Args:
        audio_path: Audio file list (or ``None`` / ``[]`` when not provided).
        audio_folder: Folder to scan for audio when ``audio_path`` is empty.
        language: Target language forwarded to ``format_audio_list``.
        whisper_model: Whisper model forwarded to ``format_audio_list``.
        out_path: Output root; the dataset is written to ``<out_path>/dataset``.
        train_csv: Unused; kept for interface compatibility.
        eval_csv: Unused; kept for interface compatibility.
        progress: Progress reporter forwarded as ``gradio_progress``.

    Returns:
        ``(message, train_meta, eval_meta)``. On any failure the last two
        elements are empty strings.
    """
    out_path = os.path.join(out_path, "dataset")
    os.makedirs(out_path, exist_ok=True)

    # Pick the input source; the conditions mirror the original precedence.
    if audio_path is not None and audio_path != []:
        audio_inputs = audio_path
    elif audio_folder is not None:
        audio_inputs = get_audio_files_from_folder(audio_folder)
    else:
        return "You should provide either audio files or a folder containing audio files!", "", ""

    try:
        train_meta, eval_meta, audio_total_size = format_audio_list(
            audio_inputs,
            whisper_model=whisper_model,
            target_language=language,
            out_path=out_path,
            gradio_progress=progress,
        )
    except Exception:
        # Report the failure to the caller instead of crashing; the full
        # traceback is also printed to the console.
        traceback.print_exc()
        error = traceback.format_exc()
        return (
            "The data processing was interrupted due to an error! "
            "Please check the console to verify the full error message! "
            f"\n Error summary: {error}",
            "",
            "",
        )

    # Require at least 120 units (reported to the user as 2 minutes) of audio.
    if audio_total_size < 120:
        message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
        print(message)
        return message, "", ""

    print("Dataset Processed!")
    return "Dataset Processed!", train_meta, eval_meta


def train_model(custom_model, version, language, train_csv, eval_csv,
                num_epochs, batch_size, grad_accum, output_path,
                max_audio_length):
    """Run GPT fine-tuning and stage the results in ``<output_path>/ready``.

    Any existing ``<output_path>/run`` entry is removed first. After
    training, the best checkpoint is copied to ``ready/unoptimize_model.pth``
    and the speaker reference to ``ready/reference.wav``.

    Args:
        custom_model: Forwarded to ``train_gpt``.
        version: Forwarded to ``train_gpt``.
        language: Forwarded to ``train_gpt``.
        train_csv: Training metadata CSV; must be non-empty.
        eval_csv: Evaluation metadata CSV; must be non-empty.
        num_epochs: Forwarded to ``train_gpt``.
        batch_size: Forwarded to ``train_gpt``.
        grad_accum: Forwarded to ``train_gpt``.
        output_path: Root output directory.
        max_audio_length: Multiplied by 22050 and truncated to ``int``
            before being passed to ``train_gpt``.

    Returns:
        ``(message, config_path, vocab_file, ft_xtts_checkpoint,
        speaker_xtts_path, speaker_reference_new_path)``. On failure the
        last five elements are empty strings.
    """
    run_dir = Path(output_path) / "run"

    # Remove the previous run. ``os.remove`` cannot delete a directory, so
    # use ``shutil.rmtree`` for directories and ``unlink`` for anything else.
    if run_dir.exists():
        if run_dir.is_dir() and not run_dir.is_symlink():
            shutil.rmtree(run_dir)
        else:
            run_dir.unlink()

    if not train_csv or not eval_csv:
        return (
            "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields!",
            "", "", "", "", "",
        )

    try:
        max_audio_length = int(max_audio_length * 22050)
        (speaker_xtts_path, config_path, original_xtts_checkpoint,
         vocab_file, exp_path, speaker_wav) = train_gpt(
            custom_model, version, language, num_epochs, batch_size,
            grad_accum, train_csv, eval_csv,
            output_path=output_path, max_audio_length=max_audio_length,
        )
    except Exception:
        traceback.print_exc()
        error = traceback.format_exc()
        return (
            "The training was interrupted due to an error! "
            "Please check the console to check the full error message! "
            f"\n Error summary: {error}",
            "", "", "", "", "",
        )

    ready_dir = Path(output_path) / "ready"
    # Make sure the destination exists before copying into it.
    ready_dir.mkdir(parents=True, exist_ok=True)

    best_checkpoint = os.path.join(exp_path, "best_model.pth")
    shutil.copy(best_checkpoint, ready_dir / "unoptimize_model.pth")
    ft_xtts_checkpoint = os.path.join(ready_dir, "unoptimize_model.pth")

    speaker_reference_path = Path(speaker_wav)
    speaker_reference_new_path = ready_dir / "reference.wav"
    shutil.copy(speaker_reference_path, speaker_reference_new_path)

    print("Model training done!")
    return (
        "Model training done!",
        config_path,
        vocab_file,
        ft_xtts_checkpoint,
        speaker_xtts_path,
        speaker_reference_new_path,
    )


def optimize_model(out_path, clear_train_data):
    """Strip training-only state from the fine-tuned checkpoint.

    Loads ``ready/unoptimize_model.pth``, drops the ``"optimizer"`` entry
    and every ``"model"`` key containing ``"dvae"``, saves the result as
    ``ready/model.pth`` and deletes the unoptimized file. Optionally
    removes the ``run`` and/or ``dataset`` folders beforehand.

    Args:
        out_path: Root output directory.
        clear_train_data: ``"run"``, ``"dataset"`` or ``"all"`` to delete
            the matching folder(s); any other value deletes nothing.

    Returns:
        ``(message, checkpoint_path)``; ``checkpoint_path`` is ``""`` when
        the unoptimized model is missing.
    """
    out_path = Path(out_path)
    ready_dir = out_path / "ready"
    run_dir = out_path / "run"
    dataset_dir = out_path / "dataset"

    # Best-effort cleanup: a locked folder is reported but does not abort.
    cleanup_targets = (
        (run_dir, {"run", "all"}),
        (dataset_dir, {"dataset", "all"}),
    )
    for target, selectors in cleanup_targets:
        if clear_train_data in selectors and target.exists():
            try:
                shutil.rmtree(target)
            except PermissionError as e:
                print(f"An error occurred while deleting {target}: {e}")

    model_path = ready_dir / "unoptimize_model.pth"
    if not model_path.is_file():
        return "Unoptimized model not found in ready folder", ""

    # NOTE: torch.load unpickles the file; only load checkpoints produced
    # by this pipeline, never untrusted ones.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))

    # Tolerate a checkpoint that no longer carries optimizer state.
    checkpoint.pop("optimizer", None)
    for key in list(checkpoint["model"].keys()):
        if "dvae" in key:
            del checkpoint["model"][key]

    optimized_model_file_name = "model.pth"
    optimized_model = ready_dir / optimized_model_file_name

    # Save first, delete second, so a failed save does not lose the model.
    torch.save(checkpoint, optimized_model)
    os.remove(model_path)

    ft_xtts_checkpoint = str(optimized_model)
    return f"Model optimized and saved at {ft_xtts_checkpoint}!", ft_xtts_checkpoint


def load_params(out_path):
    """Locate the dataset metadata produced by a previous preprocessing run.

    Args:
        out_path: Root output directory containing the ``dataset`` folder.

    Returns:
        ``(message, train_csv_path, eval_csv_path, language)``. The paths
        are ``Path`` objects and are not checked for existence. ``language``
        is the stripped content of ``dataset/lang.txt`` or ``None`` when
        that file is absent. When the ``dataset`` folder is missing the
        last three elements are empty strings.
    """
    dataset_path = Path(out_path) / "dataset"
    if not dataset_path.exists():
        return "The output folder does not exist!", "", "", ""

    eval_train = dataset_path / "metadata_train.csv"
    eval_csv = dataset_path / "metadata_eval.csv"
    lang_file_path = dataset_path / "lang.txt"

    current_language = None
    if lang_file_path.exists():
        current_language = lang_file_path.read_text(encoding="utf-8").strip()

    print(current_language)
    return "The data has been updated", eval_train, eval_csv, current_language


def load_params_tts(out_path, version):
    """Collect the file paths needed to load the fine-tuned TTS model.

    Args:
        out_path: Root output directory containing the ``ready`` folder.
        version: Unused; kept for interface compatibility.

    Returns:
        ``(message, checkpoint, config, vocab, speaker, reference)`` with
        the paths as strings. If any file is missing, the message lists the
        missing paths and the remaining five elements are empty strings.
    """
    ready_dir = Path(out_path) / "ready"

    xtts_config_path = ready_dir / "config.json"
    xtts_vocab_path = ready_dir / "vocab.json"
    xtts_checkpoint_path = ready_dir / "model.pth"
    xtts_speaker_path = ready_dir / "speaker.pth"
    speaker_reference_path = ready_dir / "reference.wav"

    # Checked in this order so the reported list is stable.
    required_files = (
        xtts_config_path,
        xtts_vocab_path,
        xtts_checkpoint_path,
        xtts_speaker_path,
        speaker_reference_path,
    )
    missing_files = [str(path) for path in required_files if not path.exists()]

    if missing_files:
        return (
            f"The following files are missing from the ready folder: {', '.join(missing_files)}",
            "", "", "", "", "",
        )

    print("Loaded parameters for TTS.")
    return (
        "Loaded parameters for TTS.",
        str(xtts_checkpoint_path),
        str(xtts_config_path),
        str(xtts_vocab_path),
        str(xtts_speaker_path),
        str(speaker_reference_path),
    )