ChatTTS-Forge / modules /api /impl /google_api.py
zhzluke96
update
1df74c6
raw
history blame
7 kB
import base64
import io
import logging
from typing import Literal, Optional

import soundfile as sf
from fastapi import HTTPException
from pydantic import BaseModel

from modules.Enhancer.ResembleEnhance import (
    apply_audio_enhance,
    apply_audio_enhance_full,
)
from modules.api.Api import APIManager
from modules.synthesize_audio import synthesize_audio
from modules.utils import audio
from modules.utils.audio import apply_prosody_to_audio_data
from modules.normalization import text_normalize
from modules import generate_audio as generate
from modules.speaker import speaker_mgr
from modules.ssml_parser.SSMLParser import create_ssml_parser
from modules.SynthesizeSegments import (
    SynthesizeSegments,
    combine_audio_segments,
)
from modules.api import utils as api_utils
class SynthesisInput(BaseModel):
    """Content to synthesize: supply either plain ``text`` or ``ssml``.

    google_text_synthesize uses ``text`` when it is non-empty, otherwise
    ``ssml``; when both are empty the request is rejected with HTTP 422.
    """

    # Plain text; passed through text_normalize before synthesis.
    text: str = ""
    # SSML document; parsed into segments whose "text" is normalized and
    # then synthesized segment by segment.
    ssml: str = ""
class VoiceSelectionParams(BaseModel):
    """Voice and sampling parameters (Google ``VoiceSelectionParams`` shape plus extensions)."""

    # Accepted for Google API compatibility; not used for synthesis by
    # google_text_synthesize in this module.
    languageCode: str = "ZH-CN"
    # Speaker name; must be a speaker registered with speaker_mgr,
    # otherwise the request is rejected with HTTP 422.
    name: str = "female2"
    # Style name, resolved together with the speaker via api_utils.calc_spk_style.
    style: str = ""
    # Sampling parameters used for plain-text input. A falsy value (e.g. 0)
    # falls back to the value resolved by calc_spk_style. They are not
    # forwarded by google_text_synthesize for SSML input.
    temperature: float = 0.3
    topP: float = 0.7
    topK: int = 20
    # Inference seed; a falsy value (0) is replaced with 42 by the handler.
    seed: int = 42
    # end_of_sentence token passed to the synthesizers (end_of_sentence / eos);
    # an empty string falls back to "[uv_break]".
    eos: str = "[uv_break]"
class AudioConfig(BaseModel):
    """Output audio settings (Google ``AudioConfig`` shape plus batching options)."""

    # Output encoding; google_text_synthesize only accepts "mp3" or "wav"
    # and rejects anything else with HTTP 422.
    audioEncoding: api_utils.AudioFormat = "mp3"
    # Passed to apply_prosody_to_audio_data as rate / pitch / volume.
    speakingRate: float = 1
    pitch: float = 0
    volumeGainDb: float = 0
    # NOTE(review): does not affect the output; google_text_synthesize uses
    # the sample rate returned by synthesis (and enhancement, if enabled).
    sampleRateHertz: int = 24000
    # Forwarded to synthesis as batch_size.
    batchSize: int = 1
    # Forwarded to synthesis as spliter_threshold / spliter_thr
    # (presumably a text-splitting length threshold — confirm).
    spliterThreshold: int = 100
class EnhancerConfig(BaseModel):
    """Optional post-processing of the synthesized audio with the Resemble Enhance model."""

    # When True, the audio is passed through apply_audio_enhance_full.
    enabled: bool = False
    # NOTE(review): not read by google_text_synthesize; only the
    # resemble-enhance path is wired up in this module.
    model: str = "resemble-enhance"
    # The remaining fields are forwarded unchanged to apply_audio_enhance_full.
    # nfe: presumably the number of function evaluations of the solver — confirm.
    nfe: int = 32
    solver: Literal["midpoint", "rk4", "euler"] = "midpoint"
    lambd: float = 0.5
    tau: float = 0.5
class GoogleTextSynthesizeRequest(BaseModel):
    """Request body for ``POST /v1/text:synthesize`` (Google TTS-compatible)."""

    input: SynthesisInput
    voice: VoiceSelectionParams
    audioConfig: AudioConfig
    # Optional: omitted or explicit null means "no enhancement". Typed as
    # Optional so an explicit null validates instead of being rejected.
    enhancerConfig: Optional[EnhancerConfig] = None
class GoogleTextSynthesizeResponse(BaseModel):
    """Response body for ``POST /v1/text:synthesize``."""

    # Data URI of the form "data:audio/<mp3|wav>;base64,<payload>";
    # the handler returns a data URI rather than bare base64.
    audioContent: str
def _synthesize_from_text(
    text: str,
    voice,
    params: dict,
    infer_seed: int,
    batch_size: int,
    spliter_threshold: int,
    eos: str,
):
    """Normalize plain text and synthesize it; return ``(sample_rate, audio_data)``.

    ``voice`` is the request's VoiceSelectionParams; ``params`` is the
    speaker/style dict returned by ``api_utils.calc_spk_style``.
    """
    normalized = text_normalize(text, is_end=True)
    # Falsy voice values (e.g. 0) fall back to what speaker/style resolved to.
    return synthesize_audio(
        normalized,
        temperature=(
            voice.temperature
            if voice.temperature
            else params.get("temperature", 0.3)
        ),
        top_P=voice.topP if voice.topP else params.get("top_p", 0.7),
        top_K=voice.topK if voice.topK else params.get("top_k", 20),
        spk=params.get("spk", -1),
        infer_seed=infer_seed,
        prompt1=params.get("prompt1", ""),
        prompt2=params.get("prompt2", ""),
        prefix=params.get("prefix", ""),
        batch_size=batch_size,
        spliter_threshold=spliter_threshold,
        end_of_sentence=eos,
    )


def _synthesize_from_ssml(ssml: str, batch_size: int, spliter_threshold: int, eos: str):
    """Parse SSML, synthesize every segment and concatenate them.

    Returns ``(sample_rate, audio_data)``.

    Raises:
        HTTPException: 422 when parsing yields no segments.
    """
    parser = create_ssml_parser()
    segments = parser.parse(ssml)
    if len(segments) == 0:
        raise HTTPException(
            status_code=422, detail="The SSML text is empty or parsing failed."
        )
    for seg in segments:
        seg["text"] = text_normalize(seg["text"], is_end=True)
    synthesize = SynthesizeSegments(
        batch_size=batch_size, eos=eos, spliter_thr=spliter_threshold
    )
    audio_segments = synthesize.synthesize_segments(segments)
    combined_audio = combine_audio_segments(audio_segments)
    return audio.pydub_to_np(combined_audio)


def _encode_audio_base64(audio_data, sample_rate: int, audio_format: str) -> str:
    """Write audio as WAV (transcoded to MP3 when requested) and return it base64-encoded."""
    buffer = io.BytesIO()
    sf.write(buffer, audio_data, sample_rate, format="wav")
    buffer.seek(0)
    if audio_format == "mp3":
        buffer = api_utils.wav_to_mp3(buffer)
    return base64.b64encode(buffer.read()).decode("utf-8")


async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
    """Handle a Google Cloud TTS-style ``text:synthesize`` request.

    Plain ``input.text`` takes precedence over ``input.ssml``. The result is
    optionally enhanced, has speaking rate / pitch / volume applied, is
    encoded as WAV or MP3 and returned as a base64 data URI.

    Args:
        request: The parsed request body.

    Returns:
        ``{"audioContent": "data:audio/<fmt>;base64,<data>"}``, matching
        ``GoogleTextSynthesizeResponse``.

    Raises:
        HTTPException: 422 for an unregistered voice name, an unsupported
            encoding, missing input, or SSML that yields no segments;
            500 for any other failure during synthesis or encoding.
    """
    # NOTE(review): this is an async handler, but synthesis below is called
    # synchronously, so it blocks the event loop while it runs. Consider
    # offloading it to a thread pool if the model is safe to call concurrently.
    synthesis_input = request.input
    voice = request.voice
    audio_config = request.audioConfig
    enhancer_config = request.enhancerConfig
    if enhancer_config is None:
        # enhancerConfig is optional in the request body; a missing config
        # means "enhancement disabled" rather than an AttributeError on None.
        enhancer_config = EnhancerConfig()

    # TODO: voice.languageCode is accepted for Google API compatibility but
    # is not used yet (it could be forwarded to the text normalizer).
    voice_name = voice.name
    # `or` fallbacks: any falsy value (0, "", None) selects the default.
    infer_seed = voice.seed or 42
    eos = voice.eos or "[uv_break]"
    audio_format = audio_config.audioEncoding or "mp3"
    speaking_rate = audio_config.speakingRate or 1
    pitch = audio_config.pitch or 0
    volume_gain_db = audio_config.volumeGainDb or 0
    batch_size = audio_config.batchSize or 1
    spliter_threshold = audio_config.spliterThreshold or 100
    # NOTE: audioConfig.sampleRateHertz is not honoured; the output sample
    # rate is whatever synthesis (and optional enhancement) produces.

    # calc_spk_style could also resolve seed-style speaker specs, but this
    # endpoint deliberately supports only speakers registered in the list.
    # Validate the request before doing any expensive work.
    if speaker_mgr.get_speaker(voice_name) is None:
        raise HTTPException(
            status_code=422, detail="The specified voice name is not supported."
        )
    if audio_format not in ("mp3", "wav"):
        raise HTTPException(
            status_code=422, detail="Invalid audio encoding format specified."
        )
    if not synthesis_input.text and not synthesis_input.ssml:
        raise HTTPException(
            status_code=422, detail="Either text or SSML input must be provided."
        )
    # TODO: validate enhancer parameters when enhancer_config.enabled.

    try:
        if synthesis_input.text:
            params = api_utils.calc_spk_style(spk=voice_name, style=voice.style)
            sample_rate, audio_data = _synthesize_from_text(
                synthesis_input.text,
                voice,
                params,
                infer_seed=infer_seed,
                batch_size=batch_size,
                spliter_threshold=spliter_threshold,
                eos=eos,
            )
        else:
            sample_rate, audio_data = _synthesize_from_ssml(
                synthesis_input.ssml,
                batch_size=batch_size,
                spliter_threshold=spliter_threshold,
                eos=eos,
            )

        if enhancer_config.enabled:
            audio_data, sample_rate = apply_audio_enhance_full(
                audio_data=audio_data,
                sr=sample_rate,
                nfe=enhancer_config.nfe,
                solver=enhancer_config.solver,
                lambd=enhancer_config.lambd,
                tau=enhancer_config.tau,
            )

        audio_data = apply_prosody_to_audio_data(
            audio_data,
            rate=speaking_rate,
            pitch=pitch,
            volume=volume_gain_db,
            sr=sample_rate,
        )
        base64_string = _encode_audio_base64(audio_data, sample_rate, audio_format)
    except HTTPException:
        # Already a well-formed client/server error (e.g. empty SSML): pass through.
        raise
    except Exception as e:
        logging.exception(e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "audioContent": f"data:audio/{audio_format.lower()};base64,{base64_string}"
    }
def setup(app: APIManager):
    """Register the Google Cloud TTS-compatible ``POST /v1/text:synthesize`` route.

    Args:
        app: The API manager; its ``post(...)`` call returns a decorator that
            is applied to ``google_text_synthesize``.
    """
    # The description previously claimed voice.name may be a speaker seed,
    # but google_text_synthesize rejects any name not registered in
    # speaker_mgr with HTTP 422; the doc line is corrected accordingly.
    app.post(
        "/v1/text:synthesize",
        response_model=GoogleTextSynthesizeResponse,
        description="""
google api document: <br/>
[https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize](https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize)
- 多个属性在本系统中无用仅仅是为了兼容google api
- voice 中的 topP, topK, temperature 为本系统中的参数
- voice.name 即 speaker name (必须是 speakers 列表中已存在的 speaker, 不支持 speaker seed)
- voice.seed 为 infer seed (可在webui中测试具体作用)
- 编码格式影响的是 audioContent 的二进制格式,所以所有format都是返回带有base64数据的json
""",
    )(google_text_synthesize)