ChatTTS-Forge / modules /api /impl /google_api.py
zhzluke96
update
d2b7e94
raw
history blame
5.6 kB
from typing import Union
from fastapi import HTTPException
from pydantic import BaseModel
from modules.api import utils as api_utils
from modules.api.Api import APIManager
from modules.api.impl.handler.SSMLHandler import SSMLHandler
from modules.api.impl.handler.TTSHandler import TTSHandler
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
from modules.api.impl.model.enhancer_model import EnhancerConfig
from modules.speaker import Speaker, speaker_mgr
class SynthesisInput(BaseModel):
    """Content to synthesize, given either as plain text or as SSML.

    The synthesis handler in this module checks ``text`` first and only looks
    at ``ssml`` when ``text`` is empty or missing. When neither is set the
    request is rejected with HTTP 422.
    """

    # Plain text to speak.
    text: Union[str, None] = None
    # SSML document to speak; only used when ``text`` is empty or None.
    ssml: Union[str, None] = None
class VoiceSelectionParams(BaseModel):
    """Voice selection and sampling parameters of a synthesis request."""

    # Not consumed by synthesis in this module; presumably kept for
    # compatibility with the Google API request shape.
    languageCode: str = "ZH-CN"
    # Speaker name. The handler rejects the request with HTTP 422 when
    # speaker_mgr.get_speaker(name) returns None.
    name: str = "female2"
    # Style; resolved together with ``name`` by api_utils.calc_spk_style.
    style: str = ""
    # Sampling parameters forwarded to ChatTTSConfig
    # (temperature / top_p / top_k).
    temperature: float = 0.3
    topP: float = 0.7
    topK: int = 20
    # Inference seed forwarded to InferConfig. The handler replaces a falsy
    # value (e.g. 0) with 42.
    seed: int = 42
    # End-of-sentence marker forwarded to InferConfig. The handler replaces
    # an empty value with "[uv_break]".
    eos: str = "[uv_break]"
class AudioConfig(BaseModel):
    """Output audio format, post-processing and batching options."""

    # Output format; also determines the MIME type of the returned data URI
    # (mp3 is reported as "audio/mpeg").
    audioEncoding: AudioFormat = AudioFormat.mp3
    # Forwarded to AdjustConfig.speaking_rate; a falsy value falls back to 1.
    speakingRate: float = 1
    # Forwarded to AdjustConfig.pitch.
    pitch: float = 0
    # Forwarded to AdjustConfig.volume_gain_db.
    volumeGainDb: float = 0
    # Accepted but not applied by the handler in this module (marked TODO
    # there).
    sampleRateHertz: int = 24000
    # Forwarded to InferConfig.batch_size; a falsy value falls back to 1
    # (note: not to the declared default of 4).
    batchSize: int = 4
    # Forwarded to InferConfig.spliter_threshold; a falsy value falls back
    # to 100.
    spliterThreshold: int = 100
class GoogleTextSynthesizeRequest(BaseModel):
    """Request body of the Google-compatible ``text:synthesize`` endpoint."""

    # What to speak (plain text or SSML).
    input: SynthesisInput
    # Speaker, style and sampling parameters.
    voice: VoiceSelectionParams
    # Output format, post-processing and batching options.
    audioConfig: AudioConfig
    # Optional enhancer settings, passed through to the handlers unchanged.
    # Declared explicitly as nullable: a bare ``EnhancerConfig = None`` only
    # makes the field omittable, and pydantic v2 would reject an explicit
    # ``null`` sent by a client.
    enhancerConfig: Union[EnhancerConfig, None] = None
class GoogleTextSynthesizeResponse(BaseModel):
    """Response body of the Google-compatible ``text:synthesize`` endpoint."""

    # Data URI built by the handler in this module:
    # "data:<mime type>;base64,<encoded audio>".
    audioContent: str
async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
    """Synthesize speech for a Google-TTS-style request.

    ``request.input.text`` takes precedence over ``request.input.ssml``. The
    encoded audio is returned base64-encoded inside a ``data:`` URI.

    Args:
        request: Parsed request body.

    Returns:
        ``{"audioContent": "data:<mime type>;base64,<payload>"}``.

    Raises:
        HTTPException: 422 when the voice name is not a known speaker or when
            neither text nor SSML is supplied; 500 (with the original error
            message as ``detail``) when synthesis fails unexpectedly.
    """
    import logging

    synthesis_input = request.input
    voice = request.voice
    audio_config = request.audioConfig

    # TODO: voice.languageCode should probably be passed to the normalizer.
    # TODO: audioConfig.sampleRateHertz is accepted but not applied yet.

    # Falsy values (0, "", None) fall back to the defaults below.
    # NOTE(review): this means a seed of 0 is silently replaced by 42.
    infer_seed = voice.seed or 42
    eos = voice.eos or "[uv_break]"
    speaking_rate = audio_config.speakingRate or 1
    pitch = audio_config.pitch or 0
    volume_gain_db = audio_config.volumeGainDb or 0
    batch_size = audio_config.batchSize or 1
    spliter_threshold = audio_config.spliterThreshold or 100

    audio_format = audio_config.audioEncoding
    if not isinstance(audio_format, AudioFormat) and isinstance(audio_format, str):
        audio_format = AudioFormat(audio_format)

    params = api_utils.calc_spk_style(spk=voice.name, style=voice.style)
    speaker = params.get("spk")

    # calc_spk_style can also parse seed-style names, but this endpoint only
    # supports speakers that exist in the speaker list.
    if speaker_mgr.get_speaker(voice.name) is None or not isinstance(
        speaker, Speaker
    ):
        raise HTTPException(
            status_code=422, detail="The specified voice name is not supported."
        )

    tts_config = ChatTTSConfig(
        style=params.get("style", ""),
        temperature=voice.temperature,
        top_k=voice.topK,
        top_p=voice.topP,
    )
    infer_config = InferConfig(
        batch_size=batch_size,
        spliter_threshold=spliter_threshold,
        eos=eos,
        seed=infer_seed,
    )
    adjust_config = AdjustConfig(
        speaking_rate=speaking_rate,
        pitch=pitch,
        volume_gain_db=volume_gain_db,
    )

    mime_type = f"audio/{audio_format.value}"
    if audio_format == AudioFormat.mp3:
        mime_type = "audio/mpeg"

    # Reject empty input before entering the catch-all below, so a plain
    # client error is not logged with a traceback like a server failure.
    use_text = bool(synthesis_input.text)
    if not use_text and not synthesis_input.ssml:
        raise HTTPException(
            status_code=422, detail="Invalid input text or ssml specified."
        )

    try:
        if use_text:
            handler = TTSHandler(
                text_content=synthesis_input.text,
                spk=speaker,
                tts_config=tts_config,
                infer_config=infer_config,
                adjust_config=adjust_config,
                enhancer_config=request.enhancerConfig,
            )
        else:
            handler = SSMLHandler(
                ssml_content=synthesis_input.ssml,
                infer_config=infer_config,
                adjust_config=adjust_config,
                enhancer_config=request.enhancerConfig,
            )
        # NOTE(review): called synchronously inside an async handler; if this
        # call blocks it will stall the event loop - confirm.
        base64_string = handler.enqueue_to_base64(format=audio_format)
    except Exception as e:
        logging.exception(e)
        if isinstance(e, HTTPException):
            # Already a well-formed HTTP error; pass it through unchanged.
            raise
        # Keep the original error as __cause__ for debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"audioContent": f"data:{mime_type};base64,{base64_string}"}
def setup(app: APIManager):
    """Register the Google-compatible synthesis route on *app*.

    Binds ``google_text_synthesize`` to ``POST /v1/text:synthesize`` with
    ``GoogleTextSynthesizeResponse`` as the response model.
    """
    # Passed verbatim to app.post as the route description.
    route_description = """
google api document: <br/>
[https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize](https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize)
- 多个属性在本系统中无用仅仅是为了兼容google api
- voice 中的 topP, topK, temperature 为本系统中的参数
- voice.name 即 speaker name (或者speaker seed)
- voice.seed 为 infer seed (可在webui中测试具体作用)
- 编码格式影响的是 audioContent 的二进制格式,所以所有format都是返回带有base64数据的json
"""
    register_route = app.post(
        "/v1/text:synthesize",
        response_model=GoogleTextSynthesizeResponse,
        description=route_description,
    )
    register_route(google_text_synthesize)