"""Inference logic. | |
Copyright PolyAI Limited. | |
""" | |
import argparse | |
import json | |
import logging | |
import os | |
import time | |
from pathlib import Path | |
import numpy as np | |
import soundfile as sf | |
import torch | |
from einops import rearrange | |
from librosa.util import normalize | |
from pyannote.audio import Inference | |
from transformers import GenerationConfig, T5ForConditionalGeneration | |
import constants as c | |
from data.collation import get_text_semantic_token_collater | |
from data.semantic_dataset import TextTokenizer | |
from modules.s2a_model import Pheme | |
from modules.vocoder import VocoderType | |
# How many times one token can be generated | |
MAX_TOKEN_COUNT = 100 | |
logging.basicConfig(level=logging.DEBUG) | |
device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu" | |


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--text", type=str,
        default="I gotta say, I would never expect that to happen!"
    )
    parser.add_argument(
        "--manifest_path", type=str, default="demo/manifest.json")
    parser.add_argument("--outputdir", type=str, default="demo/")
    parser.add_argument("--featuredir", type=str, default="demo/")
    parser.add_argument(
        "--text_tokens_file", type=str,
        default="ckpt/unique_text_tokens.k2symbols"
    )
    parser.add_argument("--t2s_path", type=str, default="ckpt/t2s/")
    parser.add_argument(
        "--a2s_path", type=str, default="ckpt/s2a/s2a.ckpt")
    parser.add_argument("--target_sample_rate", type=int, default=16_000)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top_k", type=int, default=210)
    parser.add_argument("--voice", type=str, default="male_voice")

    return parser.parse_args()


class PhemeClient():
    def __init__(self, args):
        self.args = args
        self.outputdir = args.outputdir
        self.target_sample_rate = args.target_sample_rate
        self.featuredir = Path(args.featuredir).expanduser()
        self.collater = get_text_semantic_token_collater(args.text_tokens_file)
        self.phonemizer = TextTokenizer()
        self.load_manifest(args.manifest_path)

        # T2S model
        self.t2s = T5ForConditionalGeneration.from_pretrained(args.t2s_path)
        self.t2s.to(device)
        self.t2s.eval()

        # S2A model
        self.s2a = Pheme.load_from_checkpoint(args.a2s_path)
        self.s2a.to(device=device)
        self.s2a.eval()

        # Vocoder
        vocoder = VocoderType["SPEECHTOKENIZER"].get_vocoder(None, None)
        self.vocoder = vocoder.to(device)
        self.vocoder.eval()

        # Speaker embedding model (requires HUGGING_FACE_HUB_TOKEN to be set)
        self.spkr_embedding = Inference(
            "pyannote/embedding",
            window="whole",
            use_auth_token=os.environ["HUGGING_FACE_HUB_TOKEN"],
        )

    def load_manifest(self, input_path):
        # Index manifest entries by the audio file name (without .wav)
        input_file = {}
        with open(input_path, "rb") as f:
            for line in f:
                temp = json.loads(line)
                input_file[temp["audio_filepath"].split(".wav")[0]] = temp

        self.input_file = input_file

    def lazy_decode(self, decoder_output, symbol_table):
        # Map decoder ids back to symbols, keeping only numeric semantic tokens
        semantic_tokens = map(lambda x: symbol_table[x], decoder_output)
        semantic_tokens = [int(x) for x in semantic_tokens if x.isdigit()]

        return np.array(semantic_tokens)

    def infer_text(self, text, voice, sampling_config):
        semantic_prompt = np.load(self.args.featuredir + "/audios-speech-tokenizer/semantic/" + f"{voice}.npy")  # noqa
        phones_seq = self.phonemizer(text)[0]
        input_ids = self.collater([phones_seq])
        input_ids = input_ids.type(torch.IntTensor).to(device)

        labels = [str(lbl) for lbl in semantic_prompt]
        labels = self.collater([labels])[:, :-1]
        decoder_input_ids = labels.to(device).long()
        logging.debug(f"decoder_input_ids: {decoder_input_ids}")

        # Regenerate until no token repeats more than MAX_TOKEN_COUNT times
        # in a row
        counts = 1E10
        while counts > MAX_TOKEN_COUNT:
            output_ids = self.t2s.generate(
                input_ids, decoder_input_ids=decoder_input_ids,
                generation_config=sampling_config).sequences

            # check repetitiveness
            _, counts = torch.unique_consecutive(
                output_ids, return_counts=True)
            counts = max(counts).item()

        output_semantic = self.lazy_decode(
            output_ids[0], self.collater.idx2token)

        # remove the prompt
        return output_semantic[len(semantic_prompt):].reshape(1, -1)

    def _load_speaker_emb(self, element_id_prompt):
        wav, _ = sf.read(self.featuredir / element_id_prompt)
        audio = normalize(wav) * 0.95
        speaker_emb = self.spkr_embedding(
            {
                "waveform": torch.FloatTensor(audio).unsqueeze(0),
                "sample_rate": self.target_sample_rate
            }
        ).reshape(1, -1)

        return speaker_emb

    def _load_prompt(self, prompt_file_path):
        element_id_prompt = Path(prompt_file_path).stem
        acoustic_path_prompt = self.featuredir / "audios-speech-tokenizer/acoustic" / f"{element_id_prompt}.npy"  # noqa
        semantic_path_prompt = self.featuredir / "audios-speech-tokenizer/semantic" / f"{element_id_prompt}.npy"  # noqa

        acoustic_prompt = np.load(acoustic_path_prompt).squeeze().T
        semantic_prompt = np.load(semantic_path_prompt)[None]

        return acoustic_prompt, semantic_prompt

    def infer_acoustic(self, output_semantic, prompt_file_path):
        semantic_tokens = output_semantic.reshape(1, -1)
        acoustic_tokens = np.full(
            [semantic_tokens.shape[1], 7], fill_value=c.PAD)

        acoustic_prompt, semantic_prompt = self._load_prompt(prompt_file_path)  # noqa

        # Prepend prompt
        acoustic_tokens = np.concatenate(
            [acoustic_prompt, acoustic_tokens], axis=0)
        semantic_tokens = np.concatenate([
            semantic_prompt, semantic_tokens], axis=1)

        # Add speaker
        acoustic_tokens = np.pad(
            acoustic_tokens, [[1, 0], [0, 0]], constant_values=c.SPKR_1)
        semantic_tokens = np.pad(
            semantic_tokens, [[0, 0], [1, 0]], constant_values=c.SPKR_1)

        if self.s2a.hp.use_spkr_emb:
            speaker_emb = self._load_speaker_emb(prompt_file_path)
            speaker_emb = np.repeat(
                speaker_emb, semantic_tokens.shape[1], axis=0)
            speaker_emb = torch.from_numpy(speaker_emb).to(device)
        else:
            speaker_emb = None

        acoustic_tokens = torch.from_numpy(
            acoustic_tokens).unsqueeze(0).to(device).long()
        semantic_tokens = torch.from_numpy(semantic_tokens).to(device).long()
        start_t = torch.tensor(
            [acoustic_prompt.shape[0]], dtype=torch.long, device=device)
        length = torch.tensor([
            semantic_tokens.shape[1]], dtype=torch.long, device=device)

        codes = self.s2a.model.inference(
            acoustic_tokens,
            semantic_tokens,
            start_t=start_t,
            length=length,
            maskgit_inference=True,
            speaker_emb=speaker_emb
        )

        # Remove the prompt
        synth_codes = codes[:, :, start_t:]
        synth_codes = rearrange(synth_codes, "b c t -> c b t")

        return synth_codes

    def generate_audio(self, text, voice, sampling_config, prompt_file_path):
        start_time = time.time()
        output_semantic = self.infer_text(
            text, voice, sampling_config
        )
        logging.debug(f"semantic_tokens: {time.time() - start_time}")

        start_time = time.time()
        codes = self.infer_acoustic(output_semantic, prompt_file_path)
        logging.debug(f"acoustic_tokens: {time.time() - start_time}")

        start_time = time.time()
        audio_array = self.vocoder.decode(codes)
        audio_array = rearrange(audio_array, "1 1 T -> T").cpu().numpy()
        logging.debug(f"vocoder time: {time.time() - start_time}")

        return audio_array

    def infer(
        self, text, voice="male_voice", temperature=0.7,
        top_k=210, max_new_tokens=750,
    ):
        sampling_config = GenerationConfig.from_pretrained(
            self.args.t2s_path,
            top_k=top_k,
            num_beams=1,
            do_sample=True,
            temperature=temperature,
            num_return_sequences=1,
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=True,
            output_scores=True
        )

        voice_data = self.input_file[voice]
        prompt_file_path = voice_data["audio_prompt_filepath"]

        # Prepend the prompt transcript so generation continues the prompt
        text = voice_data["text"] + " " + text

        audio_array = self.generate_audio(
            text, voice, sampling_config, prompt_file_path)

        return audio_array
if __name__ == "__main__": | |
args = parse_arguments() | |
args.outputdir = Path(args.outputdir).expanduser() | |
args.outputdir.mkdir(parents=True, exist_ok=True) | |
args.manifest_path = Path(args.manifest_path).expanduser() | |
client = PhemeClient(args) | |
audio_array = client.infer(args.text, voice=args.voice) | |
sf.write(os.path.join( | |
args.outputdir, f"{args.voice}.wav"), audio_array, | |
args.target_sample_rate | |
) | |
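
# Example invocation (a sketch; assumes this file is saved as inference.py and
# that the default ckpt/ and demo/ paths from parse_arguments() exist locally):
#
#   HUGGING_FACE_HUB_TOKEN=<token> python inference.py \
#       --text "I gotta say, I would never expect that to happen!" \
#       --voice male_voice --outputdir demo/
#
# The synthesised waveform is written to <outputdir>/<voice>.wav at
# --target_sample_rate (16 kHz by default). The HUGGING_FACE_HUB_TOKEN
# environment variable is required because PhemeClient loads the
# pyannote/embedding model with use_auth_token.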