import base64
import contextlib
import io
import os
import tempfile
import wave
from typing import List

import torch
import numpy as np
from pydantic import BaseModel
import spaces
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from trainer.io import get_user_data_dir
from TTS.utils.manage import ModelManager

os.environ["COQUI_TOS_AGREED"] = "1"

# NUM_THREADS overrides the thread count; os.cpu_count() may return None,
# so fall back to a single thread rather than crashing on int(None).
torch.set_num_threads(int(os.environ.get("NUM_THREADS") or os.cpu_count() or 1))

device = torch.device("cuda" if os.environ.get("USE_CPU", "0") == "0" else "cpu")
# Compare `device.type`: a torch.device never compares equal to a plain str,
# so a `device == "cuda"` test would silently always be False.
if device.type == "cuda" and not torch.cuda.is_available():
    raise RuntimeError("CUDA device unavailable, please use Dockerfile.cpu instead.")

# Prefer a custom model directory when it contains a config.json; otherwise
# download the stock XTTS v2 checkpoint into the TTS user data dir.
custom_model_path = os.environ.get("CUSTOM_MODEL_PATH", "/app/tts_models")
if os.path.isdir(custom_model_path) and os.path.isfile(
    os.path.join(custom_model_path, "config.json")
):
    model_path = custom_model_path
    print("Loading custom model from", model_path, flush=True)
else:
    print("Loading default model", flush=True)
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    print("Downloading XTTS Model:", model_name, flush=True)
    ModelManager().download_model(model_name)
    model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
    print("XTTS Model downloaded", flush=True)

print("Loading XTTS", flush=True)
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
model = Xtts.init_from_config(config)
# DeepSpeed is opt-in (USE_DEEPSPEED=1) and only considered on CUDA. It needs
# the separately installed `deepspeed` package, so it is not enabled implicitly.
use_deepspeed = device.type == "cuda" and os.environ.get("USE_DEEPSPEED", "0") == "1"
model.load_checkpoint(config, checkpoint_dir=model_path, eval=True, use_deepspeed=use_deepspeed)
model.to(device)
print("XTTS Loaded.", flush=True)

print("Running XTTS Server ...", flush=True)


@contextlib.contextmanager
def _audio_path(wav_file):
    """Yield a filesystem path containing the audio from *wav_file*.

    A path (``str`` / ``os.PathLike``) is yielded as-is. A binary file-like
    object is copied into a temporary file, which is fully written and closed
    before being yielded and is removed again on exit.
    """
    if isinstance(wav_file, (str, os.PathLike)):
        yield os.fspath(wav_file)
        return
    fd, temp_path = tempfile.mkstemp()
    try:
        with os.fdopen(fd, "wb") as temp:
            temp.write(wav_file.read())
        yield temp_path
    finally:
        os.remove(temp_path)


# @app.post("/clone_speaker")
@spaces.GPU
def predict_speaker(wav_file):
    """Compute conditioning inputs from a reference audio file.

    Args:
        wav_file: Path to the reference audio, or a binary file-like object
            with a ``read()`` method.

    Returns:
        dict with ``"gpt_cond_latent"`` and ``"speaker_embedding"``, each a
        (nested) list of floats squeezed and rounded through float16. This is
        the format ``TTSInputs`` accepts.
    """
    with _audio_path(wav_file) as audio_path, torch.inference_mode():
        gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path)
    return {
        "gpt_cond_latent": gpt_cond_latent.cpu().squeeze().half().tolist(),
        "speaker_embedding": speaker_embedding.cpu().squeeze().half().tolist(),
    }


def postprocess(wav):
    """Convert a float waveform into 16-bit PCM samples.

    Args:
        wav: A tensor of samples, or a list of tensors (e.g. streamed chunks)
            that is concatenated along dim 0 first. Samples outside [-1, 1]
            are clipped.

    Returns:
        ``np.int16`` array with a new leading axis of size 1.
    """
    if isinstance(wav, list):
        wav = torch.cat(wav, dim=0)
    samples = wav.detach().cpu().numpy()[np.newaxis]
    samples = np.clip(samples, -1, 1)
    return (samples * 32767).astype(np.int16)


def encode_audio_common(
    frame_input, encode_base64=True, sample_rate=24000, sample_width=2, channels=1
):
    """Wrap raw PCM frames in a WAV container.

    Args:
        frame_input: Raw PCM bytes. Pass ``b""`` to get a header-only WAV.
        encode_base64: If True, return the WAV as a base64 ``str``. Otherwise
            return the raw WAV ``bytes``.
        sample_rate: Frames per second written to the header.
        sample_width: Bytes per sample (2 for int16 PCM).
        channels: Number of interleaved channels.
    """
    wav_buf = io.BytesIO()
    with wave.open(wav_buf, "wb") as writer:
        writer.setnchannels(channels)
        writer.setsampwidth(sample_width)
        writer.setframerate(sample_rate)
        writer.writeframes(frame_input)
    wav_bytes = wav_buf.getvalue()
    if encode_base64:
        return base64.b64encode(wav_bytes).decode("utf-8")
    return wav_bytes


class StreamingInputs(BaseModel):
    """Request body for the (currently disabled) streaming TTS endpoint."""

    speaker_embedding: List[float]
    gpt_cond_latent: List[List[float]]
    text: str
    language: str
    add_wav_header: bool = True
    # Parsed with int() by the streaming generator.
    stream_chunk_size: str = "20"


# NOTE(review): disabled streaming generator. It needs FastAPI's `Body`, which
# this file does not import. It is annotated `dict` but reads attributes, so
# it presumably should take a `StreamingInputs` when re-enabled.
#
# def predict_streaming_generator(parsed_input: dict = Body(...)):
#     speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
#     gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
#     text = parsed_input.text
#     language = parsed_input.language
#
#     stream_chunk_size = int(parsed_input.stream_chunk_size)
#     add_wav_header = parsed_input.add_wav_header
#
#     chunks = model.inference_stream(
#         text,
#         language,
#         gpt_cond_latent,
#         speaker_embedding,
#         stream_chunk_size=stream_chunk_size,
#         enable_text_splitting=True
#     )
#
#     for i, chunk in enumerate(chunks):
#         chunk = postprocess(chunk)
#         if i == 0 and add_wav_header:
#             yield encode_audio_common(b"", encode_base64=False)
#             yield chunk.tobytes()
#         else:
#             yield chunk.tobytes()
# NOTE(review): disabled streaming endpoint. It needs FastAPI's `app` and
# `StreamingResponse`, neither of which this file imports.
#
# @app.post("/tts_stream")
# def predict_streaming_endpoint(parsed_input: StreamingInputs):
#     return StreamingResponse(
#         predict_streaming_generator(parsed_input),
#         media_type="audio/wav",
#     )


class TTSInputs(BaseModel):
    """Request body for ``predict_speech``.

    The conditioning fields use the list format produced by ``predict_speaker``
    and ``get_speakers``. The remaining fields are sampling controls forwarded
    to ``model.inference``.
    """

    speaker_embedding: List[float]
    gpt_cond_latent: List[List[float]]
    text: str
    language: str
    temperature: float
    speed: float
    top_k: int
    top_p: float


# @app.post("/tts")
@spaces.GPU
def predict_speech(parsed_input: TTSInputs):
    """Synthesize speech for ``parsed_input.text`` with the given speaker.

    Returns:
        Base64-encoded WAV (``str``) produced by ``encode_audio_common``.
    """
    # 1-D embedding -> (1, D, 1); flat latents -> (1, T, 1024).
    speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
    gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)

    # Keyword arguments are required here. In Xtts.inference, `do_sample` and
    # `num_beams` sit between `top_p` and `speed`, so passing these positionally
    # would put `speed` into `do_sample` and speed would never be applied.
    out = model.inference(
        parsed_input.text,
        parsed_input.language,
        gpt_cond_latent,
        speaker_embedding,
        temperature=parsed_input.temperature,
        length_penalty=1.0,
        repetition_penalty=2.0,
        top_k=parsed_input.top_k,
        top_p=parsed_input.top_p,
        speed=parsed_input.speed,
    )
    wav = postprocess(torch.tensor(out["wav"]))
    return encode_audio_common(wav.tobytes())


# @app.get("/studio_speakers")
def get_speakers():
    """Return the model's built-in speakers as JSON-friendly conditioning data.

    Returns:
        ``{speaker_name: {"speaker_embedding": [...], "gpt_cond_latent": [...]}}``
        with values squeezed and rounded through float16. Returns ``{}`` when
        the model exposes no speaker table.
    """
    speaker_manager = getattr(model, "speaker_manager", None)
    speakers = getattr(speaker_manager, "speakers", None)
    if not speakers:
        return {}
    return {
        name: {
            "speaker_embedding": latents["speaker_embedding"].cpu().squeeze().half().tolist(),
            "gpt_cond_latent": latents["gpt_cond_latent"].cpu().squeeze().half().tolist(),
        }
        for name, latents in speakers.items()
    }


# @app.get("/languages")
def get_languages():
    """Return the language codes listed in the loaded model config."""
    return config.languages