import base64
import io
import os
import tempfile
import wave
from typing import List

import numpy as np
import spaces
import torch
from pydantic import BaseModel
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.manage import ModelManager
from trainer.io import get_user_data_dir
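
# Configuration comes from environment variables, all read below:
#   COQUI_TOS_AGREED=1  accepts the Coqui XTTS license non-interactively
#   NUM_THREADS         torch CPU thread count (defaults to os.cpu_count())
#   USE_CPU=1           run inference on CPU instead of CUDA
#   CUSTOM_MODEL_PATH   directory containing a fine-tuned XTTS checkpoint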
os.environ["COQUI_TOS_AGREED"] = "1"

torch.set_num_threads(int(os.environ.get("NUM_THREADS", os.cpu_count())))

device = torch.device("cuda" if os.environ.get("USE_CPU", "0") == "0" else "cpu")
# Compare device.type, not the torch.device object against a plain string.
if device.type == "cuda" and not torch.cuda.is_available():
    raise RuntimeError("CUDA device unavailable, please use Dockerfile.cpu instead.")

custom_model_path = os.environ.get("CUSTOM_MODEL_PATH", "/app/tts_models")

if os.path.isfile(os.path.join(custom_model_path, "config.json")):
    model_path = custom_model_path
    print("Loading custom model from", model_path, flush=True)
else:
    print("Loading default model", flush=True)
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    print("Downloading XTTS Model:", model_name, flush=True)
    ModelManager().download_model(model_name)
    model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
    print("XTTS Model downloaded", flush=True)

print("Loading XTTS", flush=True)
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
model = Xtts.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir=model_path, eval=True, use_deepspeed=(device.type == "cuda"))
model.to(device)
print("XTTS Loaded.", flush=True)

print("Running XTTS Server ...", flush=True)


@spaces.GPU
def predict_speaker(wav_file):
    """Compute conditioning inputs from a reference audio file."""
    # Accept either a path or an open file-like object.
    if isinstance(wav_file, str):
        with open(wav_file, "rb") as f:
            audio_bytes = f.read()
    else:
        audio_bytes = wav_file.read()

    # Write the reference audio to a named temporary file so the model can
    # read it back by path, then remove it afterwards.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp:
        temp.write(audio_bytes)
        temp_audio_name = temp.name
    try:
        with torch.inference_mode():
            gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
                temp_audio_name
            )
    finally:
        os.remove(temp_audio_name)

    return {
        "gpt_cond_latent": gpt_cond_latent.cpu().squeeze().half().tolist(),
        "speaker_embedding": speaker_embedding.cpu().squeeze().half().tolist(),
    }


def postprocess(wav):
    """Post-process the output waveform into 16-bit PCM."""
    if isinstance(wav, list):
        wav = torch.cat(wav, dim=0)
    wav = wav.detach().cpu().numpy()
    wav = wav[None, :]
    wav = np.clip(wav, -1, 1)
    wav = (wav * 32767).astype(np.int16)
    return wav


def encode_audio_common(
    frame_input, encode_base64=True, sample_rate=24000, sample_width=2, channels=1
):
    """Return WAV audio, base64-encoded by default, otherwise as raw bytes."""
    wav_buf = io.BytesIO()
    with wave.open(wav_buf, "wb") as vfout:
        vfout.setnchannels(channels)
        vfout.setsampwidth(sample_width)
        vfout.setframerate(sample_rate)
        vfout.writeframes(frame_input)

    wav_buf.seek(0)
    if encode_base64:
        return base64.b64encode(wav_buf.getbuffer()).decode("utf-8")
    return wav_buf.read()


class StreamingInputs(BaseModel):
    speaker_embedding: List[float]
    gpt_cond_latent: List[List[float]]
    text: str
    language: str
    add_wav_header: bool = True
    stream_chunk_size: str = "20"
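

# StreamingInputs has no consumer in this file; the streaming endpoint that
# presumably used it is missing. The generator below is a minimal sketch of
# one, assuming Xtts.inference_stream from the TTS package; the name
# predict_streaming_generator and the chunk handling are illustrative, not a
# definitive implementation.
def predict_streaming_generator(parsed_input: StreamingInputs):
    speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
    gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)

    chunks = model.inference_stream(
        parsed_input.text,
        parsed_input.language,
        gpt_cond_latent,
        speaker_embedding,
        stream_chunk_size=int(parsed_input.stream_chunk_size),
    )
    for i, chunk in enumerate(chunks):
        chunk = postprocess(chunk)
        if i == 0 and parsed_input.add_wav_header:
            # Send a WAV header (with empty payload) once, then raw PCM chunks.
            yield encode_audio_common(b"", encode_base64=False)
        yield chunk.tobytes()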


class TTSInputs(BaseModel):
    speaker_embedding: List[float]
    gpt_cond_latent: List[List[float]]
    text: str
    language: str
    temperature: float
    speed: float
    top_k: int
    top_p: float


@spaces.GPU
def predict_speech(parsed_input: TTSInputs):
    speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
    gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)

    text = parsed_input.text
    language = parsed_input.language
    temperature = parsed_input.temperature
    speed = parsed_input.speed
    top_k = parsed_input.top_k
    top_p = parsed_input.top_p
    length_penalty = 1.0
    repetition_penalty = 2.0
    # Pass the sampling options by keyword: Xtts.inference also accepts
    # do_sample and num_beams ahead of speed, so positional arguments past
    # speaker_embedding are easy to misbind (here, speed would land on
    # do_sample).
    out = model.inference(
        text,
        language,
        gpt_cond_latent,
        speaker_embedding,
        temperature=temperature,
        length_penalty=length_penalty,
        repetition_penalty=repetition_penalty,
        top_k=top_k,
        top_p=top_p,
        speed=speed,
    )

    wav = postprocess(torch.tensor(out["wav"]))

    return encode_audio_common(wav.tobytes())


def get_speakers():
    if hasattr(model, "speaker_manager") and hasattr(model.speaker_manager, "speakers"):
        return {
            speaker: {
                "speaker_embedding": data["speaker_embedding"].cpu().squeeze().half().tolist(),
                "gpt_cond_latent": data["gpt_cond_latent"].cpu().squeeze().half().tolist(),
            }
            for speaker, data in model.speaker_manager.speakers.items()
        }
    return {}


def get_languages():
    return config.languages
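

# Example end-to-end usage (a sketch; "speaker.wav" is a hypothetical local
# reference clip):
#
#   latents = predict_speaker("speaker.wav")
#   audio_b64 = predict_speech(TTSInputs(
#       speaker_embedding=latents["speaker_embedding"],
#       gpt_cond_latent=latents["gpt_cond_latent"],
#       text="Hello world.",
#       language="en",
#       temperature=0.75,
#       speed=1.0,
#       top_k=50,
#       top_p=0.85,
#   ))
#
# audio_b64 is then a base64-encoded mono 16-bit WAV at 24 kHz.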