xtts / xtts.py
rrg92's picture
Minor fix to avoid a ZeroGPU pickle error due to a buffered binary file
ac9a77d
import base64
import io
import os
import tempfile
import wave
import torch
import numpy as np
from typing import List
from pydantic import BaseModel
import spaces
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from trainer.io import get_user_data_dir
from TTS.utils.manage import ModelManager
# Presumably lets ModelManager fetch XTTS without an interactive licence prompt.
os.environ["COQUI_TOS_AGREED"] = "1"

# Thread count for CPU-side torch ops; os.cpu_count() may return None, so fall
# back to a single thread instead of crashing on int(None).
torch.set_num_threads(int(os.environ.get("NUM_THREADS", os.cpu_count() or 1)))

# USE_CPU=0 (the default) selects CUDA; any other value forces CPU.
device = torch.device("cuda" if os.environ.get("USE_CPU", "0") == "0" else "cpu")

# Compare device.type, not the torch.device object itself: a torch.device never
# compares equal to a plain str, so a `device == "cuda"` test is always False.
if device.type == "cuda" and not torch.cuda.is_available():
    raise RuntimeError("CUDA device unavailable, please use Dockerfile.cpu instead.")

# Prefer a locally provided model directory when it contains a config.json;
# otherwise download the stock XTTS v2 checkpoint.
custom_model_path = os.environ.get("CUSTOM_MODEL_PATH", "/app/tts_models")
if os.path.isfile(os.path.join(custom_model_path, "config.json")):
    model_path = custom_model_path
    print("Loading custom model from", model_path, flush=True)
else:
    print("Loading default model", flush=True)
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    print("Downloading XTTS Model:", model_name, flush=True)
    ModelManager().download_model(model_name)
    model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
    print("XTTS Model downloaded", flush=True)

print("Loading XTTS", flush=True)
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
model = Xtts.init_from_config(config)
# DeepSpeed is opt-in (USE_DEEPSPEED=1, CUDA only). Because of the str/device
# comparison above it was never actually enabled, so switching it on by default
# may break deployments that do not have the deepspeed package installed.
use_deepspeed = device.type == "cuda" and os.environ.get("USE_DEEPSPEED", "0") == "1"
model.load_checkpoint(config, checkpoint_dir=model_path, eval=True, use_deepspeed=use_deepspeed)
model.to(device)
print("XTTS Loaded.", flush=True)
print("Running XTTS Server ...", flush=True)
# @app.post("/clone_speaker")
@spaces.GPU
def predict_speaker(wav_file):
    """Compute XTTS conditioning inputs from a reference audio clip.

    Args:
        wav_file: Path (``str``) to the reference audio, or a readable binary
            file-like object. Accepting a path lets callers avoid handing an
            open (unpicklable) file object across the ``@spaces.GPU`` boundary.

    Returns:
        dict with ``"gpt_cond_latent"`` and ``"speaker_embedding"``: each tensor
        returned by ``model.get_conditioning_latents`` moved to CPU, squeezed,
        cast to float16 and converted to (nested) Python lists.
    """
    temp_path = None
    if isinstance(wav_file, str):
        # Read straight from the given path: no copy, and no handle left open.
        audio_path = wav_file
    else:
        # get_conditioning_latents takes a path, so spill the stream to a temp
        # file. Leaving the `with` closes (and therefore flushes) the file before
        # the model reads it.
        audio_bytes = wav_file.read()
        with tempfile.NamedTemporaryFile(delete=False) as temp:
            temp.write(audio_bytes)
        audio_path = temp_path = temp.name

    try:
        with torch.inference_mode():
            gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
                audio_path
            )
    finally:
        # Don't leave per-request audio copies behind on disk.
        if temp_path is not None:
            os.remove(temp_path)

    return {
        "gpt_cond_latent": gpt_cond_latent.cpu().squeeze().half().tolist(),
        "speaker_embedding": speaker_embedding.cpu().squeeze().half().tolist(),
    }
def postprocess(wav):
    """Convert model audio output into 16-bit PCM samples.

    Args:
        wav: A torch tensor of float samples, or a list of tensors that are
            first concatenated along dim 0.

    Returns:
        An ``np.int16`` array with an extra leading axis of size 1. Samples are
        clipped to [-1, 1], scaled by 32767 and truncated toward zero.
    """
    frames = torch.cat(wav, dim=0) if isinstance(wav, list) else wav
    samples = frames.clone().detach().cpu().numpy()
    # Prepend a leading axis; the slice keeps the whole original first axis.
    first_len = int(samples.shape[0])
    samples = samples[None, :first_len]
    scaled = np.clip(samples, -1, 1) * 32767
    return scaled.astype(np.int16)
def encode_audio_common(
    frame_input, encode_base64=True, sample_rate=24000, sample_width=2, channels=1
):
    """Wrap raw PCM frames in an in-memory WAV container.

    Args:
        frame_input: Raw PCM bytes to write as the WAV data chunk.
        encode_base64: When true (the default), return the WAV as a base64
            ``str``; otherwise return the raw WAV ``bytes``.
        sample_rate: Frame rate written to the WAV header.
        sample_width: Bytes per sample written to the WAV header.
        channels: Channel count written to the WAV header.

    Returns:
        The complete WAV file, either base64-encoded (``str``) or raw ``bytes``.
    """
    container = io.BytesIO()
    with wave.open(container, "wb") as writer:
        writer.setframerate(sample_rate)
        writer.setsampwidth(sample_width)
        writer.setnchannels(channels)
        writer.writeframes(frame_input)
    wav_bytes = container.getvalue()
    if not encode_base64:
        return wav_bytes
    return base64.b64encode(wav_bytes).decode("utf-8")
class StreamingInputs(BaseModel):
    """Request body for the (currently commented-out) streaming TTS endpoint."""

    # Speaker embedding as a flat list; the commented-out streaming generator
    # turns it into a tensor with unsqueeze(0).unsqueeze(-1).
    speaker_embedding: List[float]
    # GPT conditioning latent rows; the commented-out streaming generator
    # reshapes them to (-1, 1024) and adds a leading batch axis.
    gpt_cond_latent: List[List[float]]
    text: str
    language: str
    # When true, a WAV header (built from empty frames) is yielded before the
    # first audio chunk.
    add_wav_header: bool = True
    # Kept as a string; the streaming generator converts it with int().
    stream_chunk_size: str = "20"
#
#def predict_streaming_generator(parsed_input: dict = Body(...)):
# speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
# gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
# text = parsed_input.text
# language = parsed_input.language
#
# stream_chunk_size = int(parsed_input.stream_chunk_size)
# add_wav_header = parsed_input.add_wav_header
#
#
# chunks = model.inference_stream(
# text,
# language,
# gpt_cond_latent,
# speaker_embedding,
# stream_chunk_size=stream_chunk_size,
# enable_text_splitting=True
# )
#
# for i, chunk in enumerate(chunks):
# chunk = postprocess(chunk)
# if i == 0 and add_wav_header:
# yield encode_audio_common(b"", encode_base64=False)
# yield chunk.tobytes()
# else:
# yield chunk.tobytes()
#
#
## @app.post("/tts_stream")
#def predict_streaming_endpoint(parsed_input: StreamingInputs):
# return StreamingResponse(
# predict_streaming_generator(parsed_input),
# media_type="audio/wav",
# )
class TTSInputs(BaseModel):
    """Request body for ``predict_speech`` (non-streaming synthesis)."""

    # Flat speaker embedding; predict_speech reshapes it to (1, D, 1).
    speaker_embedding: List[float]
    # GPT conditioning latent rows; predict_speech reshapes them to (1, T, 1024).
    gpt_cond_latent: List[List[float]]
    text: str  # text to synthesize
    language: str  # NOTE(review): presumably one of get_languages(); confirm
    # Sampling/speed options forwarded by predict_speech to model.inference.
    temperature: float
    speed: float
    top_k: int
    top_p: float
# @app.post("/tts")
@spaces.GPU
def predict_speech(parsed_input: TTSInputs):
    """Synthesize ``parsed_input.text`` in the voice described by its conditioning.

    Args:
        parsed_input: Request carrying the speaker conditioning (in the list form
            produced by ``predict_speaker``), the text/language to speak, and the
            sampling options temperature, speed, top_k and top_p.

    Returns:
        A base64-encoded WAV (``str``) built by ``encode_audio_common`` with its
        default parameters from the int16 samples returned by ``postprocess``.
    """
    # Rebuild tensors: speaker embedding (1, D, 1), GPT latent (1, T, 1024).
    speaker_embedding = (
        torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
    )
    gpt_cond_latent = (
        torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
    )

    # Fixed decoding penalties; not exposed through TTSInputs.
    length_penalty = 1.0
    repetition_penalty = 2.0

    # Pass sampling options by keyword. In Xtts.inference the parameter after
    # top_p is do_sample (speed comes later), so passing them positionally made
    # the requested speed act as the do_sample flag and left speed at default.
    out = model.inference(
        parsed_input.text,
        parsed_input.language,
        gpt_cond_latent,
        speaker_embedding,
        temperature=parsed_input.temperature,
        length_penalty=length_penalty,
        repetition_penalty=repetition_penalty,
        top_k=parsed_input.top_k,
        top_p=parsed_input.top_p,
        speed=parsed_input.speed,
    )
    wav = postprocess(torch.tensor(out["wav"]))
    return encode_audio_common(wav.tobytes())
# @app.get("/studio_speakers")
def get_speakers():
    """Return the speakers registered in the model's speaker manager.

    Returns:
        ``{speaker_name: {"speaker_embedding": [...], "gpt_cond_latent": [...]}}``
        where each tensor is moved to CPU, squeezed, cast to float16 and turned
        into (nested) Python lists; ``{}`` when the model has no
        ``speaker_manager`` or that manager has no ``speakers`` attribute.
    """
    manager = getattr(model, "speaker_manager", None)
    if not hasattr(manager, "speakers"):
        return {}

    def _as_list(tensor):
        # Same compact, JSON-friendly form predict_speaker returns.
        return tensor.cpu().squeeze().half().tolist()

    voices = {}
    for name, entry in manager.speakers.items():
        voices[name] = {
            "speaker_embedding": _as_list(entry["speaker_embedding"]),
            "gpt_cond_latent": _as_list(entry["gpt_cond_latent"]),
        }
    return voices
# @app.get("/languages")
def get_languages():
    """Return the ``languages`` entry of the loaded XTTS config."""
    supported = config.languages
    return supported