# tts-xtts2-multi / app.py
# Uploaded by TaiYouWeb ("Upload 5 files", commit 5ca847f, 3.3 kB).
import io
import os
import tempfile
import threading
from typing import List, Optional

import TTS.api
import torch
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse, Response
from pydub import AudioSegment

import config
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load every model listed in config.models once at import time, keyed by the
# same ids; request handlers look models up by key (e.g. models['multi']).
# A comprehension is used so no loop variable leaks into module scope (the
# previous loop left a global named `id`, shadowing the builtin).
models = {
    model_id: TTS.api.TTS(model_name).to(device)
    for model_id, model_name in config.models.items()
}
class SynthesizeResponse(Response):
    """Response class whose declared media type is WAV audio.

    Passed as ``response_class`` to the /tts route so FastAPI documents the
    endpoint's output as ``audio/wav`` in the generated OpenAPI schema.
    """

    media_type = "audio/wav"
# ASGI application object; the /tts endpoint below is registered on it.
app = FastAPI()
# Synthesis runs in threadpool workers (see synthesize). NOTE(review): the
# thread-safety of a shared TTS model instance is unverified, so access is
# serialized with this lock, matching the old one-at-a-time behaviour.
_model_lock = threading.Lock()


def _decode_to_wav(raw: bytes) -> bytes:
    """Decode uploaded audio in any pydub-readable format and re-encode it as WAV.

    Raises:
        HTTPException: 400 if pydub cannot decode or re-encode the data.
    """
    try:
        segment = AudioSegment.from_file(io.BytesIO(raw))
        wav_buffer = io.BytesIO()
        segment.export(wav_buffer, format="wav")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error processing audio file: {e}") from e
    return wav_buffer.getvalue()


def _remove_temp_files(paths: List[str]) -> None:
    """Delete the given files, ignoring ones that are already gone or undeletable."""
    for path in paths:
        try:
            os.remove(path)
        except OSError:
            # Best-effort cleanup: a failed delete must not mask the request's
            # real result or the exception that is already propagating.
            pass


def _synthesize_to_wav(
    speaker_wav_data: List[bytes],
    speaker_idx: str,
    tts_options: dict,
) -> io.BytesIO:
    """Run synthesis synchronously and return the WAV output rewound to position 0.

    If any reference recordings are given they are converted to temporary WAV
    files and passed as ``speaker_wav``; otherwise ``speaker_idx`` is passed as
    ``speaker``. ``tts_options`` is forwarded unchanged to ``tts_to_file``.
    Blocking (ffmpeg decoding + model inference): call it from a worker thread.
    """
    temp_files: List[str] = []
    try:
        for raw in speaker_wav_data:
            wav_bytes = _decode_to_wav(raw)
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                # Register the path before writing so a failed write is still cleaned up.
                temp_files.append(tmp.name)
                tmp.write(wav_bytes)

        # Both modes share every option; they differ only in how the voice is chosen.
        voice = {"speaker_wav": temp_files} if temp_files else {"speaker": speaker_idx}

        output_buffer = io.BytesIO()
        with _model_lock:
            models['multi'].tts_to_file(file_path=output_buffer, **voice, **tts_options)
        output_buffer.seek(0)
        return output_buffer
    finally:
        _remove_temp_files(temp_files)


@app.post('/tts', response_class=SynthesizeResponse)
async def synthesize(
    text: str = Form('Hello,World!'),
    speaker_wavs: List[UploadFile] = File(None),
    speaker_idx: str = Form('Ana Florence'),
    language: str = Form('ja'),
    temperature: float = Form(0.65),
    length_penalty: float = Form(1.0),
    repetition_penalty: float = Form(2.0),
    top_k: int = Form(50),
    top_p: float = Form(0.8),
    speed: float = Form(1.0),
    enable_text_splitting: bool = Form(True),
) -> StreamingResponse:
    """Synthesize ``text`` with the 'multi' model and stream the result as WAV.

    Form fields:
        text: Text to synthesize.
        speaker_wavs: Optional reference recordings (any format pydub can decode)
            for voice cloning. When at least one is uploaded, ``speaker_idx`` is ignored.
        speaker_idx: Speaker name passed as ``speaker`` when no recordings are uploaded.
        language: Language code passed to the model.
        temperature, length_penalty, repetition_penalty, top_k, top_p, speed,
        enable_text_splitting: Generation options forwarded unchanged to ``tts_to_file``.

    Raises:
        HTTPException: 400 if an uploaded recording cannot be decoded.
    """
    # UploadFile.read is async, so uploads are read here on the event loop;
    # everything after this is blocking work and is moved to the threadpool so
    # one long synthesis does not stall every other request on the server.
    speaker_wav_data: List[bytes] = []
    for speaker_wav in speaker_wavs or []:
        speaker_wav_data.append(await speaker_wav.read())

    tts_options = {
        "text": text,
        "language": language,
        "temperature": temperature,
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "top_p": top_p,
        "speed": speed,
        "enable_text_splitting": enable_text_splitting,
    }
    output_buffer = await run_in_threadpool(
        _synthesize_to_wav, speaker_wav_data, speaker_idx, tts_options
    )
    return StreamingResponse(output_buffer, media_type="audio/wav")