"""Command-line inference for E2/F5 TTS with multi-voice batch processing.

Settings are read from a TOML configuration file and may be overridden
individually on the command line.  The text to synthesise may contain
``[voice]`` tags; each tagged section is generated with the reference audio
and reference text configured for that voice (``main`` is the default).
"""
import argparse
import codecs
import re
import tempfile
from pathlib import Path

import numpy as np
import soundfile as sf
import tomli
import torch
import torchaudio
import tqdm
from cached_path import cached_path
from einops import rearrange
from pydub import AudioSegment, silence
from transformers import pipeline
from vocos import Vocos

from model import CFM, DiT, MMDiT, UNetT
from model.utils import (convert_char_to_pinyin, get_tokenizer,
                         load_checkpoint, save_spectrogram)


def _parse_bool(value):
    """Interpret a command-line string as a boolean.

    Returns ``None`` for an empty string (so the configuration file decides),
    ``False`` for the usual "off" spellings and ``True`` for anything else,
    which keeps every previously accepted truthy value working.
    """
    normalized = value.strip().lower()
    if not normalized:
        return None
    return normalized not in ("0", "false", "f", "no", "n", "off")


parser = argparse.ArgumentParser(
    prog="python3 inference-cli.py",
    description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
    epilog="Specify options above to override one or more settings from config.",
)
parser.add_argument(
    "-c",
    "--config",
    help="Configuration file. Default=inference-cli.toml",
    default="inference-cli.toml",
)
parser.add_argument(
    "-m",
    "--model",
    help="F5-TTS | E2-TTS",
)
parser.add_argument(
    "-p",
    "--ckpt_file",
    help="The Checkpoint .pt",
)
parser.add_argument(
    "-v",
    "--vocab_file",
    help="The vocab .txt",
)
parser.add_argument(
    "-r",
    "--ref_audio",
    type=str,
    help="Reference audio file < 15 seconds.",
)
parser.add_argument(
    "-s",
    "--ref_text",
    type=str,
    # Sentinel default: lets an explicit empty string ("transcribe for me")
    # be told apart from "option not given".
    default="666",
    help="Subtitle for the reference audio.",
)
parser.add_argument(
    "-t",
    "--gen_text",
    type=str,
    help="Text to generate.",
)
parser.add_argument(
    "-f",
    "--gen_file",
    type=str,
    help="File with text to generate. Ignores --gen_text",
)
parser.add_argument(
    "-o",
    "--output_dir",
    type=str,
    help="Path to output folder..",
)
parser.add_argument(
    "--remove_silence",
    # Accept both a bare flag and an explicit value; "false"/"0"/"no" now
    # really disable the feature instead of counting as a truthy string.
    nargs="?",
    const=True,
    default=None,
    type=_parse_bool,
    help="Remove silence.",
)
parser.add_argument(
    "--load_vocoder_from_local",
    action="store_true",
    help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
)
args = parser.parse_args()

with open(args.config, "rb") as config_fh:
    config = tomli.load(config_fh)

ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
gen_text = args.gen_text if args.gen_text else config["gen_text"]
gen_file = args.gen_file if args.gen_file else config.get("gen_file", "")
if gen_file:
    # A text file, when configured, replaces the inline text.
    with codecs.open(gen_file, "r", "utf-8") as gen_fh:
        gen_text = gen_fh.read()
output_dir = args.output_dir if args.output_dir else config["output_dir"]
model = args.model if args.model else config["model"]
ckpt_file = args.ckpt_file if args.ckpt_file else ""
vocab_file = args.vocab_file if args.vocab_file else ""
remove_silence = (
    args.remove_silence
    if args.remove_silence is not None
    else config.get("remove_silence", False)
)
wave_path = Path(output_dir) / "out.wav"
spectrogram_path = Path(output_dir) / "out.png"
vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

if args.load_vocoder_from_local:
    print(f"Load vocos from local path {vocos_local_path}")
    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
    state_dict = torch.load(
        f"{vocos_local_path}/pytorch_model.bin", map_location=device
    )
    vocos.load_state_dict(state_dict)
    vocos.eval()
else:
    print("Download Vocos from huggingface charactr/vocos-mel-24khz")
    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")

print(f"Using {device} device")

# --------------------- Settings -------------------- #

target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1
nfe_step = 32  # 16, 32
cfg_strength = 2.0
ode_method = "euler"
sway_sampling_coef = -1.0
speed = 1.0
# fix_duration = 27  # None or float (duration in seconds)
fix_duration = None


def load_model(model_cls, model_cfg, ckpt_path, file_vocab):
    """Build a CFM model around ``model_cls`` and load its checkpoint.

    Args:
        model_cls: Transformer backbone class, instantiated with ``model_cfg``.
        model_cfg: Keyword arguments for ``model_cls``.
        ckpt_path: Checkpoint path handed to ``load_checkpoint``.
        file_vocab: Custom vocabulary file.  An empty string selects the
            "Emilia_ZH_EN" vocabulary with the "pinyin" tokenizer; anything
            else is loaded with the "custom" tokenizer.

    Returns:
        The model returned by ``load_checkpoint`` (called with
        ``use_ema=True``), already moved to the module-level ``device``.
    """
    if file_vocab == "":
        file_vocab = "Emilia_ZH_EN"
        tokenizer = "pinyin"
    else:
        tokenizer = "custom"

    # Report the vocabulary that is actually used, not the raw CLI value.
    print("\nvocab : ", file_vocab, tokenizer)
    print("tokenizer : ", tokenizer)
    print("model : ", ckpt_path, "\n")

    vocab_char_map, vocab_size = get_tokenizer(file_vocab, tokenizer)
    model = CFM(
        transformer=model_cls(
            **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
        ),
        mel_spec_kwargs=dict(
            target_sample_rate=target_sample_rate,
            n_mel_channels=n_mel_channels,
            hop_length=hop_length,
        ),
        odeint_kwargs=dict(
            method=ode_method,
        ),
        vocab_char_map=vocab_char_map,
    ).to(device)

    model = load_checkpoint(model, ckpt_path, device, use_ema=True)

    return model


# load models
F5TTS_model_cfg = dict(
    dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
)
E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)

if model == "F5-TTS":
    if ckpt_file == "":
        repo_name = "F5-TTS"
        exp_name = "F5TTS_Base"
        ckpt_step = 1200000
        ckpt_file = str(
            cached_path(
                f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"
            )
        )
    ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file, vocab_file)
elif model == "E2-TTS":
    if ckpt_file == "":
        repo_name = "E2-TTS"
        exp_name = "E2TTS_Base"
        ckpt_step = 1200000
        ckpt_file = str(
            cached_path(
                f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"
            )
        )
    ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file, vocab_file)
else:
    # Fail here rather than with a NameError on ema_model deep inside inference.
    raise ValueError(f"Unknown model {model!r}; expected 'F5-TTS' or 'E2-TTS'.")

asr_pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    torch_dtype=torch.float16,
    device=device,
)

# Split after ASCII punctuation only when whitespace follows (so "1,000" and
# "http://" stay intact), and directly after full-width CJK punctuation
# (; : , 。 ! ?), which is normally not followed by a space.
_SENTENCE_SPLIT_RE = re.compile(
    r"(?<=[;:,.!?])\s+|(?<=[\uff1b\uff1a\uff0c\u3002\uff01\uff1f])"
)


def chunk_text(text, max_chars=135):
    """Split ``text`` into chunks of whole sentences.

    Sentences are accumulated into a chunk while the combined UTF-8 byte
    length stays within ``max_chars``.  A single sentence longer than the
    limit is kept whole as its own chunk.

    Args:
        text (str): The text to be split.
        max_chars (int): Size budget per chunk, measured in UTF-8 bytes
            (so one CJK character counts as three).

    Returns:
        List[str]: The stripped, non-empty chunks in order.
    """
    chunks = []
    current_chunk = ""

    for sentence in _SENTENCE_SPLIT_RE.split(text):
        # The split consumed the whitespace after a single-byte terminator;
        # put one space back so joined sentences stay separated.
        if sentence and len(sentence[-1].encode("utf-8")) == 1:
            piece = sentence + " "
        else:
            piece = sentence

        fits = (
            len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8"))
            <= max_chars
        )
        if fits:
            current_chunk += piece
        else:
            if current_chunk.strip():
                chunks.append(current_chunk.strip())
            current_chunk = piece

    # Whitespace-only leftovers must not become an empty chunk.
    if current_chunk.strip():
        chunks.append(current_chunk.strip())

    return chunks


def _join_waves(waves, cross_fade_duration):
    """Concatenate waveforms, linearly cross-fading each adjacent pair."""
    if cross_fade_duration <= 0:
        # Simply concatenate
        return np.concatenate(waves)

    max_fade_samples = int(cross_fade_duration * target_sample_rate)
    joined = waves[0]
    for next_wave in waves[1:]:
        # The overlap can never be longer than either side.
        fade_samples = min(max_fade_samples, len(joined), len(next_wave))
        if fade_samples <= 0:
            # No overlap possible, concatenate
            joined = np.concatenate([joined, next_wave])
            continue

        fade_out = np.linspace(1, 0, fade_samples)
        fade_in = np.linspace(0, 1, fade_samples)
        blended = (
            joined[-fade_samples:] * fade_out + next_wave[:fade_samples] * fade_in
        )
        joined = np.concatenate(
            [joined[:-fade_samples], blended, next_wave[fade_samples:]]
        )
    return joined


def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence,
                cross_fade_duration=0.15):
    """Generate speech for every text batch and join the results.

    Args:
        ref_audio: ``(audio, sr)`` pair, a waveform tensor and its sample rate.
        ref_text: Transcript of the reference audio.
        gen_text_batches: Text chunks to synthesise, in order.
        model: Model name; not used here (sampling always goes through the
            module-level ``ema_model``).
        remove_silence: Not used here; kept for interface compatibility.
        cross_fade_duration: Cross-fade between consecutive chunks in
            seconds; ``<= 0`` concatenates without overlap.

    Returns:
        ``(final_wave, combined_spectrogram)`` as NumPy arrays.  Both are
        empty when ``gen_text_batches`` is empty.
    """
    if not gen_text_batches:
        # Nothing to synthesise: return empty results instead of crashing on
        # generated_waves[0] / np.concatenate([]).
        return (
            np.zeros(0, dtype=np.float32),
            np.zeros((n_mel_channels, 0), dtype=np.float32),
        )

    audio, sr = ref_audio
    if audio.shape[0] > 1:
        # More than one row along dim 0: average them down to a single one.
        audio = torch.mean(audio, dim=0, keepdim=True)

    rms = torch.sqrt(torch.mean(torch.square(audio)))
    # Boost quiet references up to target_rms.  A fully silent reference
    # (rms == 0) is left alone to avoid dividing by zero.
    boosted = bool(0 < rms < target_rms)
    if boosted:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)
    audio = audio.to(device)

    generated_waves = []
    spectrograms = []

    if len(ref_text[-1].encode("utf-8")) == 1:
        # Single-byte last character: separate reference and new text.
        ref_text = ref_text + " "

    ref_audio_len = audio.shape[-1] // hop_length
    ref_text_len = len(ref_text.encode("utf-8"))

    for gen_text in tqdm.tqdm(gen_text_batches):
        # Prepare the text
        text_list = [ref_text + gen_text]
        final_text_list = convert_char_to_pinyin(text_list)

        # Estimate the total duration (in mel frames) by scaling the
        # reference frames-per-byte rate to the length of the new text.
        gen_text_len = len(gen_text.encode("utf-8"))
        duration = ref_audio_len + int(
            ref_audio_len / ref_text_len * gen_text_len / speed
        )

        # inference
        with torch.inference_mode():
            generated, _ = ema_model.sample(
                cond=audio,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
            )

        generated = generated.to(torch.float32)
        # Drop the frames that correspond to the reference audio.
        generated = generated[:, ref_audio_len:, :]
        generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
        generated_wave = vocos.decode(generated_mel_spec.cpu())
        if boosted:
            # Undo the loudness boost applied to the reference.
            generated_wave = generated_wave * rms / target_rms

        # wav -> numpy
        generated_wave = generated_wave.squeeze().cpu().numpy()

        generated_waves.append(generated_wave)
        spectrograms.append(generated_mel_spec[0].cpu().numpy())

    # Combine all generated waves with cross-fading
    final_wave = _join_waves(generated_waves, cross_fade_duration)

    # Create a combined spectrogram
    combined_spectrogram = np.concatenate(spectrograms, axis=1)

    return final_wave, combined_spectrogram


def _strip_silence(aseg, keep_silence):
    """Rebuild ``aseg`` from its non-silent parts, keeping ``keep_silence`` ms of padding."""
    non_silent_segs = silence.split_on_silence(
        aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=keep_silence
    )
    non_silent_wave = AudioSegment.silent(duration=0)
    for non_silent_seg in non_silent_segs:
        non_silent_wave += non_silent_seg
    return non_silent_wave


def process_voice(ref_audio_orig, ref_text):
    """Prepare one reference voice for inference.

    The audio is stripped of long silences, clipped to its first 15 seconds
    and exported to a temporary WAV file.  When ``ref_text`` is blank the
    exported audio is transcribed with ``asr_pipe``.

    Args:
        ref_audio_orig: Path of the original reference audio file.
        ref_text: Transcript of the reference audio, or a blank string.

    Returns:
        ``(ref_audio, ref_text)``: path of the temporary WAV file and the
        (possibly transcribed) reference text.  The temporary file is created
        with ``delete=False`` and is not removed by this script.
    """
    print("Converting", ref_audio_orig)
    aseg = AudioSegment.from_file(ref_audio_orig)
    aseg = _strip_silence(aseg, keep_silence=1000)

    audio_duration = len(aseg)
    if audio_duration > 15000:
        print("Audio is over 15s, clipping to only first 15s.")
        aseg = aseg[:15000]

    # Reserve the name, close the handle, then export: writing by name to a
    # file that is still open is not portable.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        ref_audio = f.name
    aseg.export(ref_audio, format="wav")

    if not ref_text.strip():
        print("No reference text provided, transcribing reference audio...")
        ref_text = asr_pipe(
            ref_audio,
            chunk_length_s=30,
            batch_size=128,
            generate_kwargs={"task": "transcribe"},
            return_timestamps=False,
        )["text"].strip()
        print("Finished transcription")
    else:
        print("Using custom reference text...")
    return ref_audio, ref_text


def infer(ref_audio, ref_text, gen_text, model, remove_silence,
          cross_fade_duration=0.15):
    """Synthesise ``gen_text`` in the voice of one reference recording.

    Args:
        ref_audio: Path of the reference audio, loaded with ``torchaudio.load``.
        ref_text: Transcript of the reference audio.
        gen_text: Text to synthesise; split into batches by ``chunk_text``.
        model: Model name, used for the progress message and passed through.
        remove_silence: Passed through to ``infer_batch``.
        cross_fade_duration: Cross-fade between batches in seconds.

    Returns:
        The ``(final_wave, combined_spectrogram)`` pair from ``infer_batch``.
    """
    # Add the functionality to ensure it ends with ". "
    if not ref_text.endswith(". ") and not ref_text.endswith("。"):
        if ref_text.endswith("."):
            ref_text += " "
        else:
            ref_text += ". "

    # Split the input text into batches.  The budget is the reference text's
    # bytes-per-second rate times the seconds left out of 25 after the
    # reference audio.
    audio, sr = torchaudio.load(ref_audio)
    ref_seconds = audio.shape[-1] / sr
    max_chars = int(
        len(ref_text.encode("utf-8")) / ref_seconds * (25 - ref_seconds)
    )
    gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
    for i, batch_text in enumerate(gen_text_batches):
        print(f'gen_text {i}', batch_text)

    print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
    return infer_batch(
        (audio, sr), ref_text, gen_text_batches, model, remove_silence,
        cross_fade_duration,
    )


def process(ref_audio, ref_text, text_gen, model, remove_silence):
    """Generate speech for ``text_gen`` and write it to ``wave_path``.

    ``text_gen`` may contain ``[voice]`` tags.  Each tag starts a section that
    is rendered with the matching entry of ``config["voices"]``; untagged
    sections and sections with an unknown tag use the ``main`` voice built
    from ``ref_audio`` / ``ref_text``.

    Args:
        ref_audio: Reference audio path of the main voice.
        ref_text: Reference transcript of the main voice (blank to transcribe).
        text_gen: Text to synthesise, optionally with ``[voice]`` tags.
        model: Model name, passed through to ``infer``.
        remove_silence: When truthy, long silences are removed from the
            written file.
    """
    main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
    if "voices" not in config:
        voices = {"main": main_voice}
    else:
        voices = config["voices"]
        voices["main"] = main_voice

    for voice in voices:
        voices[voice]['ref_audio'], voices[voice]['ref_text'] = process_voice(
            voices[voice]['ref_audio'], voices[voice]['ref_text']
        )
        print("Voice:", voice)
        print("Ref_audio:", voices[voice]['ref_audio'])
        print("Ref_text:", voices[voice]['ref_text'])

    generated_audio_segments = []
    reg1 = r'(?=\[\w+\])'
    chunks = re.split(reg1, text_gen)
    reg2 = r'\[(\w+)\]'
    for text in chunks:
        match = re.match(reg2, text)
        # Decide on the tag that was actually matched, not on whatever
        # `voice` happened to hold from the previous iteration.
        if match and match[1] in voices:
            voice = match[1]
        else:
            if match:
                print(f"Voice {match[1]} not found, using main.")
            voice = "main"
        text = re.sub(reg2, "", text)
        gen_text = text.strip()
        if not gen_text:
            # re.split yields an empty first chunk when the text starts with
            # a tag; there is nothing to synthesise for it.
            continue
        ref_audio = voices[voice]['ref_audio']
        ref_text = voices[voice]['ref_text']
        print(f"Voice: {voice}")
        audio, _spectrogram = infer(ref_audio, ref_text, gen_text, model, remove_silence)
        generated_audio_segments.append(audio)

    if generated_audio_segments:
        final_wave = np.concatenate(generated_audio_segments)
        wave_path.parent.mkdir(parents=True, exist_ok=True)
        out_file = str(wave_path)
        # Let soundfile own the file; do not hold a second open handle on it.
        sf.write(out_file, final_wave, target_sample_rate)
        # Remove silence
        if remove_silence:
            aseg = AudioSegment.from_file(out_file)
            aseg = _strip_silence(aseg, keep_silence=500)
            aseg.export(out_file, format="wav")
        print(out_file)


process(ref_audio, ref_text, gen_text, model, remove_silence)