# Spaces: Sleeping (page-header residue from extraction; not part of the program)
import contextlib
import io
import os
import tempfile
from typing import List, Optional

import TTS.api
import torch
from pydub import AudioSegment
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import StreamingResponse, Response

import config
# Run on the GPU when torch reports CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Instantiate every entry of config.models once at import time and move it to
# `device`, keyed by its config id. A comprehension is used so that no loop
# variable (previously named `id`) is left at module scope shadowing the
# builtin id().
models = {
    model_id: TTS.api.TTS(model_name).to(device)
    for model_id, model_name in config.models.items()
}
class SynthesizeResponse(Response):
    """Response subclass whose media type is fixed to WAV audio."""

    media_type = 'audio/wav'
# FastAPI application instance for this module.
app = FastAPI()
# NOTE(review): no route decorator is visible in this excerpt; the handler has
# to be registered on `app` to be reachable -- confirm against the full file.
async def synthesize(
    text: str = Form('Hello,World!'),
    speaker_wavs: List[UploadFile] = File(None),
    speaker_idx: str = Form('Ana Florence'),
    language: str = Form('ja'),
    temperature: float = Form(0.65),
    length_penalty: float = Form(1.0),
    repetition_penalty: float = Form(2.0),
    top_k: int = Form(50),
    top_p: float = Form(0.8),
    speed: float = Form(1.0),
    enable_text_splitting: bool = Form(True)
) -> StreamingResponse:
    """Synthesize ``text`` with ``models['multi']`` and return it as WAV audio.

    When ``speaker_wavs`` contains uploads, each one is converted to WAV with
    pydub, written to a temporary file, and the list of paths is passed to the
    model as ``speaker_wav``. Otherwise ``speaker_idx`` is passed as
    ``speaker``. All temporary files are removed before the function exits.

    Args:
        text: Text to synthesize.
        speaker_wavs: Optional uploaded reference audio files.
        speaker_idx: Speaker passed to the model when no audio is uploaded.
        language: Language code forwarded to the model.
        temperature, length_penalty, repetition_penalty, top_k, top_p, speed,
        enable_text_splitting: Forwarded unchanged to ``tts_to_file``.

    Returns:
        A ``StreamingResponse`` over an in-memory buffer, media type
        ``audio/wav``.

    Raises:
        HTTPException: Status 400 when an uploaded file cannot be read or
            converted by pydub.
    """
    temp_files = []
    try:
        for speaker_wav in speaker_wavs or []:
            speaker_wav_bytes = await speaker_wav.read()
            # Normalize whatever format was uploaded to WAV via pydub.
            try:
                audio = AudioSegment.from_file(io.BytesIO(speaker_wav_bytes))
                wav_buffer = io.BytesIO()
                audio.export(wav_buffer, format="wav")
            except Exception as e:
                raise HTTPException(
                    status_code=400,
                    detail=f"Error processing audio file: {e}",
                ) from e

            # The model is handed file paths, so the converted audio goes to
            # disk. delete=False keeps the file after close; it is removed in
            # the finally block below.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav_file:
                # Register the path before writing so cleanup still happens
                # if the write itself fails.
                temp_files.append(temp_wav_file.name)
                temp_wav_file.write(wav_buffer.getvalue())

        # The two synthesis modes differ only in how the voice is selected.
        if temp_files:
            voice_kwargs = {"speaker_wav": temp_files}
        else:
            voice_kwargs = {"speaker": speaker_idx}

        output_buffer = io.BytesIO()
        # NOTE(review): this call is synchronous inside a coroutine, so it
        # presumably blocks the event loop while synthesizing -- confirm
        # whether that is acceptable for this service.
        models['multi'].tts_to_file(
            text=text,
            language=language,
            file_path=output_buffer,
            temperature=temperature,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            top_k=top_k,
            top_p=top_p,
            speed=speed,
            enable_text_splitting=enable_text_splitting,
            **voice_kwargs,
        )
        output_buffer.seek(0)
        return StreamingResponse(output_buffer, media_type="audio/wav")
    finally:
        for temp_file in temp_files:
            # A file that is already gone is fine; other OS errors propagate.
            with contextlib.suppress(FileNotFoundError):
                os.remove(temp_file)