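# ChatTTS-Forge/modules/api/impl/xtts_v2_api.py
# Repository file-viewer metadata captured with this file:
#   author: zhzluke96 | commit message: "update" | commit: 1df74c6 | size: 5.24 kB
#   (viewer links: raw, history, blame)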
import io
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from modules.api import utils as api_utils
from modules.api.Api import APIManager
import soundfile as sf
from modules import config
from modules.normalization import text_normalize
from modules.speaker import speaker_mgr
from modules.synthesize_audio import synthesize_audio
import logging
from modules.utils.audio import apply_prosody_to_audio_data
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class XTTS_V2_Settings:
    """Mutable bag of synthesis settings, initialised to fixed defaults.

    Every attribute is a plain instance attribute so callers can read and
    overwrite the values directly after construction.
    """

    def __init__(self) -> None:
        # Text handling.
        self.enable_text_splitting: bool = True
        self.stream_chunk_size: int = 100

        # Sampling parameters.
        self.temperature: float = 0.3
        self.top_p: float = 0.7
        self.top_k: int = 20
        self.length_penalty: float = 0.5
        self.repetition_penalty: float = 1.0

        # Playback speed; the default is the integer 1.
        self.speed: float = 1
class TTSSettingsRequest(BaseModel):
    # Request body for POST /v1/xtts_v2/set_tts_settings.
    # No field declares a default, so all of them must be supplied. On success
    # the handler copies each value onto the same-named attribute of the
    # XTTS_V2_Settings instance created in setup().
    stream_chunk_size: int  # the handler rejects values below 50
    temperature: float  # the handler rejects negative values
    speed: float  # the handler rejects negative values
    length_penalty: float  # the handler rejects negative values
    repetition_penalty: float  # the handler rejects negative values
    top_p: float  # the handler rejects negative values
    top_k: int  # the handler rejects negative values
    enable_text_splitting: bool  # stored as-is; the handler does not validate it
class SynthesisRequest(BaseModel):
    # Request body for POST /v1/xtts_v2/tts_to_audio.
    text: str  # text to synthesize; passed through text_normalize before synthesis
    speaker_wav: str  # despite the name, looked up as a speaker id (then via get_speaker)
    language: str  # must be supplied, but is not passed to synthesize_audio in this file
def setup(app: APIManager):
    """Register the XTTS-v2 style HTTP endpoints on ``app``.

    One ``XTTS_V2_Settings`` instance is created per call and shared by the
    handlers registered here: ``set_tts_settings`` overwrites its attributes
    and ``tts_to_audio`` reads them on every request.

    Routes registered:
        GET  /v1/xtts_v2/speakers          list known speakers
        POST /v1/xtts_v2/tts_to_audio      synthesize text and return MP3 audio
        GET  /v1/xtts_v2/tts_stream        always answers 501
        POST /v1/xtts_v2/set_tts_settings  validate and store new settings

    Args:
        app: API manager exposing FastAPI-style ``get``/``post`` decorators.
    """
    XTTSV2 = XTTS_V2_Settings()

    # Smallest stream_chunk_size accepted by set_tts_settings.
    min_stream_chunk_size = 50

    # Settings rejected when negative, listed in the order they are checked.
    non_negative_fields = (
        "temperature",
        "speed",
        "length_penalty",
        "repetition_penalty",
        "top_p",
        "top_k",
    )

    # Every setting copied from a request onto XTTSV2, in assignment order.
    all_fields = ("stream_chunk_size", *non_negative_fields, "enable_text_splitting")

    @app.get("/v1/xtts_v2/speakers")
    async def speakers():
        """Return name, voice id and an empty preview URL for each speaker."""
        return [
            {
                "name": spk.name,
                "voice_id": spk.id,
                # TODO: maybe put a "/v1/tts" endpoint URL here
                "preview_url": "",
            }
            for spk in speaker_mgr.list_speakers()
        ]

    @app.post("/v1/xtts_v2/tts_to_audio", response_class=StreamingResponse)
    async def tts_to_audio(request: SynthesisRequest):
        """Synthesize ``request.text`` with the current settings; reply with MP3.

        ``request.language`` is part of the request schema but is not used
        here.

        Raises:
            HTTPException: 400 when no speaker matches ``request.speaker_wav``.
        """
        # speaker_wav actually carries the speaker id; when the id lookup
        # finds nothing, fall back to speaker_mgr.get_speaker.
        voice_id = request.speaker_wav
        spk = speaker_mgr.get_speaker_by_id(voice_id) or speaker_mgr.get_speaker(
            voice_id
        )
        if spk is None:
            raise HTTPException(status_code=400, detail="Invalid speaker id")

        text = text_normalize(request.text, is_end=True)

        # NOTE(review): synthesize_audio is called directly inside this async
        # handler; if it blocks, other requests wait meanwhile - confirm
        # whether that is acceptable.
        # length_penalty and repetition_penalty are stored in the settings but
        # are not forwarded to synthesize_audio.
        sample_rate, audio_data = synthesize_audio(
            text=text,
            temperature=XTTSV2.temperature,
            top_P=XTTSV2.top_p,
            top_K=XTTSV2.top_k,
            spk=spk,
            spliter_threshold=XTTSV2.stream_chunk_size,
            # TODO: support configuring batch_size
            batch_size=4,
            end_of_sentence="[uv_break]",
        )

        # A falsy speed (0) skips the prosody step entirely.
        if XTTSV2.speed:
            audio_data = apply_prosody_to_audio_data(
                audio_data,
                rate=XTTSV2.speed,
                sr=sample_rate,
            )

        # Encode to WAV in memory first, then convert that buffer to MP3.
        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, audio_data, sample_rate, format="wav")
        wav_buffer.seek(0)
        mp3_buffer = api_utils.wav_to_mp3(wav_buffer)

        return StreamingResponse(mp3_buffer, media_type="audio/mpeg")

    @app.get("/v1/xtts_v2/tts_stream")
    async def tts_stream():
        """Placeholder for streaming synthesis; always raises 501."""
        raise HTTPException(status_code=501, detail="Not implemented")

    @app.post("/v1/xtts_v2/set_tts_settings")
    async def set_tts_settings(request: TTSSettingsRequest):
        """Validate ``request`` and copy its values onto the shared settings.

        Validation finishes before anything is assigned, so a rejected request
        leaves the settings untouched.

        Raises:
            HTTPException: 400 when ``stream_chunk_size`` is below the minimum
                or a numeric setting is negative; 500 on any unexpected error.
        """
        try:
            if request.stream_chunk_size < min_stream_chunk_size:
                raise HTTPException(
                    status_code=400,
                    detail=(
                        "stream_chunk_size must be greater than or equal to "
                        f"{min_stream_chunk_size}"
                    ),
                )

            # Only negative values are rejected; 0 is accepted as-is.
            for field in non_negative_fields:
                if getattr(request, field) < 0:
                    raise HTTPException(
                        status_code=400,
                        detail=f"{field} must be greater than 0",
                    )

            for field in all_fields:
                setattr(XTTSV2, field, getattr(request, field))

            return {"message": "Settings successfully applied"}
        except HTTPException:
            # Validation errors go to the client unchanged.
            raise
        except Exception as e:
            logger.exception("Failed to apply XTTS v2 settings")
            raise HTTPException(status_code=500, detail=str(e)) from e