# fireredtts/fireredtts.py
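"""FireRedTTS inference pipeline.

Wires the text tokenizer, speaker-embedding extractor, GPT speech-token
generator, and token2wav vocoder together behind a single ``synthesize``
entry point.
"""
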
import os
import json
import time

import torch

from fireredtts.modules.gpt.gpt import GPT
from fireredtts.modules import Token2Wav, MelSpectrogramExtractor
from fireredtts.modules.tokenizer.tokenizer import VoiceBpeTokenizer
from fireredtts.modules.codec.speaker import SpeakerEmbedddingExtractor
from fireredtts.utils.utils import load_audio


class FireRedTTS:
def __init__(self, config_path, pretrained_path, device="cuda"):
self.device = device
self.config = json.load(open(config_path))
self.gpt_path = os.path.join(pretrained_path, "fireredtts_gpt.pt")
self.token2wav_path = os.path.join(pretrained_path, "fireredtts_token2wav.pt")
self.speaker_extractor_path = os.path.join(
pretrained_path, "fireredtts_speaker.bin"
)
assert os.path.exists(self.token2wav_path)
assert os.path.exists(self.gpt_path)
assert os.path.exists(self.speaker_extractor_path)
        # text tokenizer
self.text_tokenizer = VoiceBpeTokenizer()
        # speaker embedding extractor
self.speaker_extractor = SpeakerEmbedddingExtractor(
ckpt_path=self.speaker_extractor_path, device=device
)
# load gpt model
self.gpt = GPT(
start_text_token=self.config["gpt"]["gpt_start_text_token"],
stop_text_token=self.config["gpt"]["gpt_stop_text_token"],
layers=self.config["gpt"]["gpt_layers"],
model_dim=self.config["gpt"]["gpt_n_model_channels"],
heads=self.config["gpt"]["gpt_n_heads"],
max_text_tokens=self.config["gpt"]["gpt_max_text_tokens"],
max_mel_tokens=self.config["gpt"]["gpt_max_audio_tokens"],
max_prompt_tokens=self.config["gpt"]["gpt_max_prompt_tokens"],
code_stride_len=self.config["gpt"]["gpt_code_stride_len"],
number_text_tokens=self.config["gpt"]["gpt_number_text_tokens"],
num_audio_tokens=self.config["gpt"]["gpt_num_audio_tokens"],
start_audio_token=self.config["gpt"]["gpt_start_audio_token"],
stop_audio_token=self.config["gpt"]["gpt_stop_audio_token"],
)
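        # GPT weights are stored under the "model" key of the checkpoint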
sd = torch.load(self.gpt_path, map_location=device)["model"]
self.gpt.load_state_dict(sd, strict=True)
self.gpt = self.gpt.to(device=device)
self.gpt.eval()
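        # build the inference model with KV caching for faster autoregressive decoding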
self.gpt.init_gpt_for_inference(kv_cache=True)
# mel-spectrogram extractor
self.mel_extractor = MelSpectrogramExtractor()
# load token2wav model
self.token2wav = Token2Wav.init_from_config(self.config)
sd = torch.load(self.token2wav_path, map_location="cpu")
self.token2wav.load_state_dict(sd, strict=True)
self.token2wav.generator.remove_weight_norm()
self.token2wav.eval()
self.token2wav = self.token2wav.to(device)

    def extract_spk_embeddings(self, prompt_wav):
        """Extract speaker embeddings from a prompt waveform.

        Args:
            prompt_wav (str): path to the prompt audio file

        Returns:
            torch.Tensor: speaker embeddings of shape [1, 512]
        """
        _, _, audio_resampled = load_audio(audiopath=prompt_wav, sampling_rate=16000)
        # speaker embeddings [1, 512]; run the extractor on the configured
        # device so CPU-only setups also work
        spk_embeddings = self.speaker_extractor(
            audio_resampled.to(device=self.device)
        ).unsqueeze(0)
        return spk_embeddings

    def do_gpt_inference(self, spk_gpt, text_tokens):
        """Sample speech tokens from the GPT given speaker conditioning.

        Args:
            spk_gpt (torch.Tensor): speaker conditioning latents for the GPT
            text_tokens (torch.Tensor): encoded text tokens, shape [1, len]

        Returns:
            torch.Tensor: selected speech-token sequence, shape [1, len]
        """
with torch.no_grad():
gpt_codes = self.gpt.generate(
cond_latents=spk_gpt,
text_inputs=text_tokens,
input_tokens=None,
do_sample=True,
top_p=0.85,
top_k=30,
temperature=0.75,
num_return_sequences=9,
num_beams=1,
length_penalty=1.0,
repetition_penalty=2.0,
output_attentions=False,
)
        # Truncate each candidate at its first EOS token; keep the whole
        # sequence if generation stopped without emitting EOS.
        seqs = []
        EOS_TOKEN = self.config["gpt"]["gpt_stop_audio_token"]
        for seq in gpt_codes:
            eos_positions = (seq == EOS_TOKEN).nonzero(as_tuple=True)[0]
            if len(eos_positions) > 0:
                seq = seq[: eos_positions[0]]
            seqs.append(seq)
        # Of the sampled candidates, keep the third-shortest: a length
        # heuristic that skips the shortest (possibly truncated) outputs.
        sorted_seqs = sorted(seqs, key=len)
        gpt_codes = sorted_seqs[2].unsqueeze(0)  # [1, len]
        return gpt_codes

    def synthesize(self, prompt_wav, text, lang="auto"):
        """Synthesize speech for ``text`` in the voice of ``prompt_wav``.

        Args:
            prompt_wav (str): path to the prompt audio file
            text (str): text to synthesize
            lang (str): language of the text ("zh", "en", or "auto")

        Returns:
            torch.Tensor: reconstructed waveform of shape [1, t]
        """
# Currently only supports Chinese and English
assert lang in ["zh", "en", "auto"]
assert os.path.exists(prompt_wav)
# text to tokens
text_tokens = self.text_tokenizer.encode(text=text, lang=lang)
text_tokens = torch.IntTensor(text_tokens).unsqueeze(0).to(self.device)
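        # hard cap on text length; longer inputs should be split upstream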
assert text_tokens.shape[-1] < 400
# extract speaker embedding
spk_embeddings = self.extract_spk_embeddings(prompt_wav=prompt_wav).unsqueeze(0)
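        # project the speaker embedding into GPT conditioning latents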
with torch.no_grad():
spk_gpt = self.gpt.reference_embedding(spk_embeddings)
# gpt inference
gpt_start_time = time.time()
gpt_codes = self.do_gpt_inference(spk_gpt=spk_gpt, text_tokens=text_tokens)
gpt_end_time = time.time()
gpt_dur = gpt_end_time - gpt_start_time
# prompt mel-spectrogram compute
prompt_mel = (
self.mel_extractor(wav_path=prompt_wav).unsqueeze(0).to(self.device)
)
# convert token to waveform (b=1, t)
voc_start_time = time.time()
rec_wavs = self.token2wav.inference(gpt_codes, prompt_mel, n_timesteps=10)
voc_end_time = time.time()
voc_dur = voc_end_time - voc_start_time
all_dur = voc_end_time - gpt_start_time
        # Real-time factor (RTF) per stage can be computed from the timings
        # above, with audio_dur = rec_wavs.shape[-1] / 24000 (24 kHz output):
        # rtf_gpt = gpt_dur / audio_dur, rtf_voc = voc_dur / audio_dur,
        # rtf_all = all_dur / audio_dur.
return rec_wavs
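

if __name__ == "__main__":
    # Minimal usage sketch. The config path, checkpoint directory, and prompt
    # audio below are illustrative assumptions, not files shipped with this
    # module; point them at your own downloads.
    import torchaudio

    tts = FireRedTTS(
        config_path="configs/config_24k.json",  # assumed config location
        pretrained_path="./pretrained_models",  # assumed checkpoint directory
    )
    rec_wavs = tts.synthesize(
        prompt_wav="examples/prompt.wav",  # assumed reference recording
        text="Hello, this is a quick FireRedTTS test.",
        lang="en",
    )
    # synthesize returns a (1, t) tensor at 24 kHz
    torchaudio.save("output.wav", rec_wavs.detach().cpu(), 24000)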