import os
import json
import time

import torch

from fireredtts.modules.gpt.gpt import GPT
from fireredtts.modules import Token2Wav, MelSpectrogramExtractor
from fireredtts.modules.tokenizer.tokenizer import VoiceBpeTokenizer
from fireredtts.modules.codec.speaker import SpeakerEmbedddingExtractor
from fireredtts.utils.utils import load_audio


class FireRedTTS:
    def __init__(self, config_path, pretrained_path, device="cuda"):
        self.device = device
        with open(config_path) as f:
            self.config = json.load(f)
        self.gpt_path = os.path.join(pretrained_path, "fireredtts_gpt.pt")
        self.token2wav_path = os.path.join(pretrained_path, "fireredtts_token2wav.pt")
        self.speaker_extractor_path = os.path.join(
            pretrained_path, "fireredtts_speaker.bin"
        )
        assert os.path.exists(self.gpt_path)
        assert os.path.exists(self.token2wav_path)
        assert os.path.exists(self.speaker_extractor_path)

        # text tokenizer
        self.text_tokenizer = VoiceBpeTokenizer()

        # speaker extractor
        self.speaker_extractor = SpeakerEmbedddingExtractor(
            ckpt_path=self.speaker_extractor_path, device=device
        )

        # load gpt model
        self.gpt = GPT(
            start_text_token=self.config["gpt"]["gpt_start_text_token"],
            stop_text_token=self.config["gpt"]["gpt_stop_text_token"],
            layers=self.config["gpt"]["gpt_layers"],
            model_dim=self.config["gpt"]["gpt_n_model_channels"],
            heads=self.config["gpt"]["gpt_n_heads"],
            max_text_tokens=self.config["gpt"]["gpt_max_text_tokens"],
            max_mel_tokens=self.config["gpt"]["gpt_max_audio_tokens"],
            max_prompt_tokens=self.config["gpt"]["gpt_max_prompt_tokens"],
            code_stride_len=self.config["gpt"]["gpt_code_stride_len"],
            number_text_tokens=self.config["gpt"]["gpt_number_text_tokens"],
            num_audio_tokens=self.config["gpt"]["gpt_num_audio_tokens"],
            start_audio_token=self.config["gpt"]["gpt_start_audio_token"],
            stop_audio_token=self.config["gpt"]["gpt_stop_audio_token"],
        )
        sd = torch.load(self.gpt_path, map_location=device)["model"]
        self.gpt.load_state_dict(sd, strict=True)
        self.gpt = self.gpt.to(device=device)
        self.gpt.eval()
        self.gpt.init_gpt_for_inference(kv_cache=True)

        # mel-spectrogram extractor
        self.mel_extractor = MelSpectrogramExtractor()

        # load token2wav model
        self.token2wav = Token2Wav.init_from_config(self.config)
        sd = torch.load(self.token2wav_path, map_location="cpu")
        self.token2wav.load_state_dict(sd, strict=True)
        self.token2wav.generator.remove_weight_norm()
        self.token2wav.eval()
        self.token2wav = self.token2wav.to(device)

    def extract_spk_embeddings(self, prompt_wav):
        """Extract a speaker embedding from a reference waveform.

        Args:
            prompt_wav (str): path to the reference (prompt) audio file

        Returns:
            torch.Tensor: speaker embedding of shape [1, 512]
        """
        _, _, audio_resampled = load_audio(audiopath=prompt_wav, sampling_rate=16000)
        # speaker embeddings [1, 512]
        spk_embeddings = self.speaker_extractor(
            audio_resampled.to(device=self.device)
        ).unsqueeze(0)
        return spk_embeddings

    def do_gpt_inference(self, spk_gpt, text_tokens):
        """Sample discrete audio codes from the GPT model.

        Args:
            spk_gpt (torch.Tensor): speaker embedding projected into the GPT space
            text_tokens (torch.Tensor): text token ids of shape [1, len]

        Returns:
            torch.Tensor: audio code sequence of shape [1, len]
        """
        with torch.no_grad():
            gpt_codes = self.gpt.generate(
                cond_latents=spk_gpt,
                text_inputs=text_tokens,
                input_tokens=None,
                do_sample=True,
                top_p=0.85,
                top_k=30,
                temperature=0.75,
                num_return_sequences=9,
                num_beams=1,
                length_penalty=1.0,
                repetition_penalty=2.0,
                output_attentions=False,
            )

        # Trim each sampled sequence at its first EOS token (if one was emitted).
        seqs = []
        eos_token = self.config["gpt"]["gpt_stop_audio_token"]
        for seq in gpt_codes:
            eos_positions = (seq == eos_token).nonzero(as_tuple=True)[0]
            if len(eos_positions) > 0:
                seq = seq[: eos_positions[0]]
            seqs.append(seq)

        # Selection heuristic: sort the 9 candidates by length and keep the
        # third-shortest, skipping degenerate (truncated) samples without
        # favoring overly long, rambling ones.
        sorted_seqs = sorted(seqs, key=len)
        gpt_codes = sorted_seqs[2].unsqueeze(0)  # [1, len]
        # sorted_len = [len(s) for s in sorted_seqs]
        # print("---sorted_len:", sorted_len)
        return gpt_codes
print("---sorted_len:", sorted_len) return gpt_codes def synthesize(self, prompt_wav, text, lang="auto"): """_summary_ Args: prompts_wav (_type_): prompts_wav path text (_type_): text lang (_type_): language of text """ # Currently only supports Chinese and English assert lang in ["zh", "en", "auto"] assert os.path.exists(prompt_wav) # text to tokens text_tokens = self.text_tokenizer.encode(text=text, lang=lang) text_tokens = torch.IntTensor(text_tokens).unsqueeze(0).to(self.device) assert text_tokens.shape[-1] < 400 # extract speaker embedding spk_embeddings = self.extract_spk_embeddings(prompt_wav=prompt_wav).unsqueeze(0) with torch.no_grad(): spk_gpt = self.gpt.reference_embedding(spk_embeddings) # gpt inference gpt_start_time = time.time() gpt_codes = self.do_gpt_inference(spk_gpt=spk_gpt, text_tokens=text_tokens) gpt_end_time = time.time() gpt_dur = gpt_end_time - gpt_start_time # prompt mel-spectrogram compute prompt_mel = ( self.mel_extractor(wav_path=prompt_wav).unsqueeze(0).to(self.device) ) # convert token to waveform (b=1, t) voc_start_time = time.time() rec_wavs = self.token2wav.inference(gpt_codes, prompt_mel, n_timesteps=10) voc_end_time = time.time() voc_dur = voc_end_time - voc_start_time all_dur = voc_end_time - gpt_start_time # rtf compute # audio_dur = rec_wavs.shape[-1] / 24000 # rtf_gpt = gpt_dur / audio_dur # rtf_voc = voc_dur / audio_dur # rtf_all = all_dur / audio_dur return rec_wavs