"""FastAPI service that converts text to speech with ParlerTTS.

Exposes:
  GET  /  -- liveness check.
  POST /  -- accepts a JSON array of items, returns each item with its
             generated audio as a base64-encoded WAV string (JSON-safe).
"""

import base64
import io
from typing import List

import torch
import soundfile as sf
from fastapi import FastAPI
from pydantic import BaseModel
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

# Initialize the FastAPI app
app = FastAPI()


class Item(BaseModel):
    """One text snippet to synthesize, plus caller-supplied metadata."""

    text: str
    name: str
    section: str


# Initialize ParlerTTS once at startup; inference runs on GPU when available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")


@app.get("/")
def greet_json():
    """Simple liveness/health endpoint."""
    return {"Hello": "World!"}


def generate_audio(text, description="Neutral voice"):
    """Synthesize *text* with ParlerTTS.

    Args:
        text: The text to speak.
        description: Natural-language voice description used to condition
            the model (ParlerTTS steers the voice via this prompt).

    Returns:
        Tuple of (audio array as 1-D numpy float array, sampling rate in Hz).
    """
    input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    # inference_mode: no gradients are needed for generation; avoids
    # accumulating autograd state on every request.
    with torch.inference_mode():
        generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()
    return audio_arr, model.config.sampling_rate


@app.post("/")
async def create_items(items: List[Item]):
    """Generate speech for each item and return JSON-serializable results.

    Each processed item echoes the input fields and carries the WAV audio
    base64-encoded under "audio".  (A StreamingResponse cannot be embedded
    inside a JSON body -- FastAPI's encoder would raise -- so the audio is
    inlined as base64 instead.)
    """
    processed_items = []
    for item in items:
        # Generate audio
        audio_arr, sample_rate = generate_audio(item.text)

        # Encode the WAV file into an in-memory buffer, then to base64 so it
        # can travel inside the JSON response.
        audio_bytes = io.BytesIO()
        sf.write(audio_bytes, audio_arr, sample_rate, format="WAV")
        audio_bytes.seek(0)  # Reset buffer position
        audio_b64 = base64.b64encode(audio_bytes.read()).decode("ascii")

        processed_items.append({
            "text": item.text,
            "name": item.name,
            "section": item.section,
            "processed": True,
            "audio": audio_b64,
        })

    return {"processed_items": processed_items}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8000)