import os
import json
import time

import torch

from fireredtts.modules.gpt.gpt import GPT
from fireredtts.modules import Token2Wav, MelSpectrogramExtractor
from fireredtts.modules.tokenizer.tokenizer import VoiceBpeTokenizer
from fireredtts.modules.codec.speaker import SpeakerEmbedddingExtractor
from fireredtts.utils.utils import load_audio

class FireRedTTS:
    def __init__(self, config_path, pretrained_path, device="cuda"):
        self.device = device
        self.config = json.load(open(config_path))

        # checkpoint paths
        self.gpt_path = os.path.join(pretrained_path, "fireredtts_gpt.pt")
        self.token2wav_path = os.path.join(pretrained_path, "fireredtts_token2wav.pt")
        self.speaker_extractor_path = os.path.join(
            pretrained_path, "fireredtts_speaker.bin"
        )
        assert os.path.exists(self.gpt_path)
        assert os.path.exists(self.token2wav_path)
        assert os.path.exists(self.speaker_extractor_path)

        # text tokenizer
        self.text_tokenizer = VoiceBpeTokenizer()

        # speaker extractor
        self.speaker_extractor = SpeakerEmbedddingExtractor(
            ckpt_path=self.speaker_extractor_path, device=device
        )

        # load gpt model
        self.gpt = GPT(
            start_text_token=self.config["gpt"]["gpt_start_text_token"],
            stop_text_token=self.config["gpt"]["gpt_stop_text_token"],
            layers=self.config["gpt"]["gpt_layers"],
            model_dim=self.config["gpt"]["gpt_n_model_channels"],
            heads=self.config["gpt"]["gpt_n_heads"],
            max_text_tokens=self.config["gpt"]["gpt_max_text_tokens"],
            max_mel_tokens=self.config["gpt"]["gpt_max_audio_tokens"],
            max_prompt_tokens=self.config["gpt"]["gpt_max_prompt_tokens"],
            code_stride_len=self.config["gpt"]["gpt_code_stride_len"],
            number_text_tokens=self.config["gpt"]["gpt_number_text_tokens"],
            num_audio_tokens=self.config["gpt"]["gpt_num_audio_tokens"],
            start_audio_token=self.config["gpt"]["gpt_start_audio_token"],
            stop_audio_token=self.config["gpt"]["gpt_stop_audio_token"],
        )
        sd = torch.load(self.gpt_path, map_location=device)["model"]
        self.gpt.load_state_dict(sd, strict=True)
        self.gpt = self.gpt.to(device=device)
        self.gpt.eval()
        # enable the kv-cache for faster autoregressive decoding
        self.gpt.init_gpt_for_inference(kv_cache=True)

        # mel-spectrogram extractor
        self.mel_extractor = MelSpectrogramExtractor()

        # load token2wav model
        self.token2wav = Token2Wav.init_from_config(self.config)
        sd = torch.load(self.token2wav_path, map_location="cpu")
        self.token2wav.load_state_dict(sd, strict=True)
        # weight norm is only needed during training; fold it out for inference
        self.token2wav.generator.remove_weight_norm()
        self.token2wav.eval()
        self.token2wav = self.token2wav.to(device)

    def extract_spk_embeddings(self, prompt_wav):
        """Extract a [1, 512] speaker embedding from a reference waveform."""
        _, _, audio_resampled = load_audio(audiopath=prompt_wav, sampling_rate=16000)
        spk_embeddings = self.speaker_extractor(
            audio_resampled.to(device=self.device)
        ).unsqueeze(0)
        return spk_embeddings
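
    # Note: the speaker encoder consumes 16 kHz audio (hence the resampling in
    # extract_spk_embeddings), while the synthesized waveform comes out at
    # 24 kHz, judging by the commented RTF computation in synthesize below.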

    def do_gpt_inference(self, spk_gpt, text_tokens):
        """Sample speech codes from the GPT model.

        Args:
            spk_gpt: speaker embedding projected into the GPT latent space
            text_tokens: encoded text tokens, shape [1, len]
        """
        with torch.no_grad():
            gpt_codes = self.gpt.generate(
                cond_latents=spk_gpt,
                text_inputs=text_tokens,
                input_tokens=None,
                do_sample=True,
                top_p=0.85,
                top_k=30,
                temperature=0.75,
                num_return_sequences=9,
                num_beams=1,
                length_penalty=1.0,
                repetition_penalty=2.0,
                output_attentions=False,
            )

        # truncate each sampled sequence at its first stop-audio token
        seqs = []
        EOS_TOKEN = self.config["gpt"]["gpt_stop_audio_token"]
        for seq in gpt_codes:
            index = (seq == EOS_TOKEN).nonzero(as_tuple=True)[0][0]
            seqs.append(seq[:index])

        # of the 9 samples, keep the third shortest: a simple length-based
        # heuristic, since very short or very long candidates are more likely
        # to be truncated or repetitive
        sorted_seqs = sorted(seqs, key=len)
        gpt_codes = sorted_seqs[2].unsqueeze(0)  # [1, len]
        return gpt_codes
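
    # For example, with gpt_stop_audio_token == 0, a sampled row
    # tensor([5, 9, 3, 0, 0]) is truncated to tensor([5, 9, 3]) before the
    # length-based selection. Note that the nonzero() indexing above raises
    # an IndexError for any sample that never emitted the stop token.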

    def synthesize(self, prompt_wav, text, lang="auto"):
        """Synthesize speech in the voice of the reference speaker.

        Args:
            prompt_wav: path to the reference (prompt) waveform
            text: text to synthesize
            lang: language of the text; currently only Chinese and English
                are supported ("zh", "en", or "auto" for auto-detection)
        """
        assert lang in ["zh", "en", "auto"]
        assert os.path.exists(prompt_wav)

        # text to tokens
        text_tokens = self.text_tokenizer.encode(text=text, lang=lang)
        text_tokens = torch.IntTensor(text_tokens).unsqueeze(0).to(self.device)
        assert text_tokens.shape[-1] < 400

        # extract speaker embedding and project it into the GPT latent space
        spk_embeddings = self.extract_spk_embeddings(prompt_wav=prompt_wav).unsqueeze(0)
        with torch.no_grad():
            spk_gpt = self.gpt.reference_embedding(spk_embeddings)

        # gpt inference
        gpt_start_time = time.time()
        gpt_codes = self.do_gpt_inference(spk_gpt=spk_gpt, text_tokens=text_tokens)
        gpt_end_time = time.time()
        gpt_dur = gpt_end_time - gpt_start_time

        # prompt mel-spectrogram, used to condition the token2wav model
        prompt_mel = (
            self.mel_extractor(wav_path=prompt_wav).unsqueeze(0).to(self.device)
        )

        # convert tokens to waveform (b=1, t)
        voc_start_time = time.time()
        rec_wavs = self.token2wav.inference(gpt_codes, prompt_mel, n_timesteps=10)
        voc_end_time = time.time()
        voc_dur = voc_end_time - voc_start_time
        all_dur = voc_end_time - gpt_start_time

        # rtf compute (output sample rate is 24 kHz)
        # audio_dur = rec_wavs.shape[-1] / 24000
        # rtf_gpt = gpt_dur / audio_dur
        # rtf_voc = voc_dur / audio_dur
        # rtf_all = all_dur / audio_dur
        return rec_wavs
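

# A minimal usage sketch, assuming the config/checkpoint layout of the
# FireRedTTS repo; the paths below are placeholders, adjust to your setup.
if __name__ == "__main__":
    import torchaudio

    tts = FireRedTTS(
        config_path="configs/config_24k.json",  # assumed config location
        pretrained_path="./pretrained_models",  # dir with the gpt/token2wav/speaker checkpoints
    )
    wav = tts.synthesize(
        prompt_wav="examples/prompt.wav",  # assumed reference speaker clip
        text="Hello, this is a test.",
        lang="en",
    )
    # token2wav returns a (1, t) float tensor; output sample rate is 24 kHz
    torchaudio.save("out.wav", wav.detach().cpu(), 24000)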