# parler/app.py
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf
import io
from starlette.responses import StreamingResponse
# Initialize the FastAPI app
app = FastAPI()
# Define a Pydantic model for the items
class Item(BaseModel):
    text: str
    name: str
    section: str
# Initialize ParlerTTS
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
# A simple GET endpoint
@app.get("/")
def greet_json():
    return {"Hello": "World!"}
# Function to generate audio from text using ParlerTTS
def generate_audio(text, description="Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."):
    # Tokenize the voice description and the text prompt
    input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_input_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    # Generate the waveform and move it to the CPU as a NumPy array
    generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()
    return audio_arr, model.config.sampling_rate
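# A quick way to exercise generate_audio on its own (a sketch; the output
# filename "sample.wav" is an arbitrary choice, not part of the original app):
#
#   audio_arr, sample_rate = generate_audio("Hello from Parler-TTS.")
#   sf.write("sample.wav", audio_arr, sample_rate)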
# A POST endpoint to receive and parse an array of JSON objects
@app.post("/")
async def create_items(items: List[Item]):
    processed_items = []
    for item in items:
        print(f"Processing item: {item.text}")
        # Generate audio for this item's text
        audio_arr, sample_rate = generate_audio(item.text)
        # # Create in-memory bytes buffer for audio
        # audio_bytes = io.BytesIO()
        # sf.write(audio_bytes, audio_arr, sample_rate, format="WAV")
        # audio_bytes.seek(0)  # Reset buffer position
        processed_item = {
            "text": item.text,
            "name": item.name,
            "section": item.section,
            "processed": True,
            # "audio": StreamingResponse(audio_bytes, media_type="audio/wav")
        }
        processed_items.append(processed_item)
    return {"processed_items": processed_items}
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8000)
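# Example request against the batch endpoint (assuming the server above is
# running locally):
#   curl -X POST http://127.0.0.1:8000/ \
#        -H "Content-Type: application/json" \
#        -d '[{"text": "Hello there.", "name": "intro", "section": "1"}]'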