"""HTTP endpoints registered under ``/v1/xtts_v2``.

``setup`` attaches four routes to the given API manager: a speaker listing,
a text-to-audio endpoint that returns MP3 data, a streaming placeholder that
answers 501, and an endpoint that updates the settings used for synthesis.
"""

import io
import logging

import soundfile as sf
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from modules.api import utils as api_utils
from modules.api.Api import APIManager
from modules import config
from modules.normalization import text_normalize
from modules.speaker import speaker_mgr
from modules.synthesize_audio import synthesize_audio
from modules.utils.audio import apply_prosody_to_audio_data

logger = logging.getLogger(__name__)

# Smallest ``stream_chunk_size`` accepted by ``set_tts_settings``.
_MIN_STREAM_CHUNK_SIZE = 50

# Settings rejected when negative, in the order they are validated.
_NON_NEGATIVE_FIELDS = (
    "temperature",
    "speed",
    "length_penalty",
    "repetition_penalty",
    "top_p",
    "top_k",
)


class XTTS_V2_Settings:
    """Mutable synthesis settings shared by the routes created in ``setup``."""

    def __init__(self):
        self.stream_chunk_size = 100
        self.temperature = 0.3
        self.speed = 1
        self.length_penalty = 0.5
        self.repetition_penalty = 1.0
        self.top_p = 0.7
        self.top_k = 20
        self.enable_text_splitting = True


class TTSSettingsRequest(BaseModel):
    """Request body for ``POST /v1/xtts_v2/set_tts_settings``."""

    stream_chunk_size: int
    temperature: float
    speed: float
    length_penalty: float
    repetition_penalty: float
    top_p: float
    top_k: int
    enable_text_splitting: bool


class SynthesisRequest(BaseModel):
    """Request body for ``POST /v1/xtts_v2/tts_to_audio``."""

    text: str
    # Used as a speaker id / name lookup key by ``tts_to_audio``.
    speaker_wav: str
    # Part of the request schema; ``tts_to_audio`` does not read it.
    language: str


def _validate_settings(request: TTSSettingsRequest) -> None:
    """Reject out-of-range settings with a 400 ``HTTPException``.

    Raises:
        HTTPException: status 400 when ``stream_chunk_size`` is below
            ``_MIN_STREAM_CHUNK_SIZE`` or any field listed in
            ``_NON_NEGATIVE_FIELDS`` is negative.
    """
    if request.stream_chunk_size < _MIN_STREAM_CHUNK_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"stream_chunk_size must be at least {_MIN_STREAM_CHUNK_SIZE}",
        )

    # NOTE: only negative values are rejected, so 0 is accepted even though
    # the message says "greater than 0". The message text is kept unchanged
    # for existing clients.
    for field_name in _NON_NEGATIVE_FIELDS:
        if getattr(request, field_name) < 0:
            raise HTTPException(
                status_code=400, detail=f"{field_name} must be greater than 0"
            )


def setup(app: APIManager):
    """Register the ``/v1/xtts_v2`` routes on ``app``.

    A single ``XTTS_V2_Settings`` instance is created here and captured by
    the route handlers, so settings applied through ``set_tts_settings``
    affect every later ``tts_to_audio`` call on this app.
    """
    XTTSV2 = XTTS_V2_Settings()

    @app.get("/v1/xtts_v2/speakers")
    async def speakers():
        """Return every speaker from ``speaker_mgr`` as a list of dicts."""
        spks = speaker_mgr.list_speakers()
        return [
            {
                "name": spk.name,
                "voice_id": spk.id,
                # TODO: maybe put a "/v1/tts" endpoint URL here
                "preview_url": "",
            }
            for spk in spks
        ]

    @app.post("/v1/xtts_v2/tts_to_audio", response_class=StreamingResponse)
    async def tts_to_audio(request: SynthesisRequest):
        """Synthesize ``request.text`` and respond with ``audio/mpeg`` data.

        Raises:
            HTTPException: status 400 when ``request.speaker_wav`` matches no
                speaker, neither by id nor through ``speaker_mgr.get_speaker``.
        """
        # speaker_wav actually carries the speaker id here.
        voice_id = request.speaker_wav

        spk = speaker_mgr.get_speaker_by_id(voice_id) or speaker_mgr.get_speaker(
            voice_id
        )
        if spk is None:
            raise HTTPException(status_code=400, detail="Invalid speaker id")

        text = text_normalize(request.text, is_end=True)
        # length_penalty and repetition_penalty are stored in the settings
        # but are not forwarded to synthesize_audio.
        sample_rate, audio_data = synthesize_audio(
            text=text,
            temperature=XTTSV2.temperature,
            top_P=XTTSV2.top_p,
            top_K=XTTSV2.top_k,
            spk=spk,
            spliter_threshold=XTTSV2.stream_chunk_size,
            # TODO: support configuring batch_size
            batch_size=4,
            end_of_sentence="[uv_break]",
        )

        # A falsy speed (0) skips the rate adjustment entirely.
        if XTTSV2.speed:
            audio_data = apply_prosody_to_audio_data(
                audio_data,
                rate=XTTSV2.speed,
                sr=sample_rate,
            )

        # Write WAV into memory first, then hand it to the MP3 converter.
        buffer = io.BytesIO()
        sf.write(buffer, audio_data, sample_rate, format="wav")
        buffer.seek(0)

        buffer = api_utils.wav_to_mp3(buffer)

        return StreamingResponse(buffer, media_type="audio/mpeg")

    @app.get("/v1/xtts_v2/tts_stream")
    async def tts_stream():
        """Placeholder route: always answers 501 Not implemented."""
        raise HTTPException(status_code=501, detail="Not implemented")

    @app.post("/v1/xtts_v2/set_tts_settings")
    async def set_tts_settings(request: TTSSettingsRequest):
        """Validate ``request`` and store its values in the shared settings.

        Returns:
            A dict with a confirmation ``message``.

        Raises:
            HTTPException: status 400 for an invalid value, status 500 for
                any other error raised while applying the settings.
        """
        try:
            # Validate everything before mutating, so a rejected request
            # leaves the current settings untouched.
            _validate_settings(request)

            XTTSV2.stream_chunk_size = request.stream_chunk_size
            XTTSV2.temperature = request.temperature
            XTTSV2.speed = request.speed
            XTTSV2.length_penalty = request.length_penalty
            XTTSV2.repetition_penalty = request.repetition_penalty
            XTTSV2.top_p = request.top_p
            XTTSV2.top_k = request.top_k
            XTTSV2.enable_text_splitting = request.enable_text_splitting
            return {"message": "Settings successfully applied"}
        except HTTPException:
            # Validation errors already carry the right status code.
            raise
        except Exception as e:
            logger.exception(e)
            raise HTTPException(status_code=500, detail=str(e)) from e