File size: 2,137 Bytes
dc5bc59
a168b4f
 
910d316
 
 
 
 
 
dc5bc59
a168b4f
dc5bc59
 
a168b4f
 
 
 
 
 
910d316
 
 
 
 
a168b4f
dc5bc59
 
 
a168b4f
910d316
 
 
 
 
 
 
 
a168b4f
 
910d316
a168b4f
 
910d316
 
 
 
 
 
 
 
a168b4f
 
 
 
910d316
 
a168b4f
 
910d316
a168b4f
910d316
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf
import io
from starlette.responses import StreamingResponse

# Initialize the FastAPI application; the GET and POST handlers below are
# registered on it through the @app.get / @app.post decorators.
app = FastAPI()

# Request-body schema for one element of the JSON array accepted by POST /.
# Pydantic validates that all three fields are present and are strings.
class Item(BaseModel):
    text: str  # text handed to generate_audio() for speech synthesis
    name: str  # echoed back unchanged in the POST / response
    section: str  # echoed back unchanged in the POST / response

# Initialize ParlerTTS once at import time so every request reuses the same
# weights. The model and its tokenizer must be loaded from the same
# checkpoint, so the identifier is defined once and shared by both loads.
MODEL_ID = "parler-tts/parler-tts-mini-v1"

# Prefer the first CUDA device when available; otherwise run on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained(MODEL_ID).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# GET / : trivial liveness/greeting route that always answers with the same
# small JSON object. (Kept as comments rather than a docstring so the
# generated OpenAPI description of the route is unchanged.)
@app.get("/")
def greet_json():
    return dict(Hello="World!")

def generate_audio(text, description="Neutral voice"):
    """Synthesize speech for ``text`` with the module-level ParlerTTS model.

    Args:
        text: The words to speak; tokenized and passed as ``prompt_input_ids``.
        description: Natural-language voice description; tokenized and passed
            as ``input_ids`` to condition the generation.

    Returns:
        Tuple ``(audio, sampling_rate)``: the generation output copied to the
        CPU as a NumPy array with size-1 dimensions squeezed away, and the
        model config's ``sampling_rate``.
    """
    def _to_ids(sentence):
        # Tokenize one string into a PyTorch tensor on the model's device.
        return tokenizer(sentence, return_tensors="pt").input_ids.to(device)

    # Keyword arguments are evaluated left to right, so the description is
    # tokenized before the prompt text.
    waveform = model.generate(
        input_ids=_to_ids(description),
        prompt_input_ids=_to_ids(text),
    )
    return waveform.cpu().numpy().squeeze(), model.config.sampling_rate

# POST / : synthesize speech for every item in a JSON array and return the
# audio inline in the JSON response.
@app.post("/")
async def create_items(items: List[Item]):
    """Synthesize speech for each submitted item.

    Each item's ``text`` is voiced with ``generate_audio``. The resulting
    waveform is written as a WAV file in memory and returned base64-encoded.
    A ``StreamingResponse`` (or raw bytes) cannot be embedded inside a JSON
    body, so the previous approach could not deliver the audio.

    Returns ``{"processed_items": [...]}``. Each entry echoes ``text``,
    ``name`` and ``section`` and adds three fields:

    * ``processed``: always ``True``.
    * ``audio``: the WAV bytes as a base64 ASCII string.
    * ``sample_rate``: the sample rate the WAV was written with.
    """
    import base64  # only this endpoint needs it

    processed_items = []
    for item in items:
        # NOTE(review): generate_audio is a synchronous call made directly
        # inside this `async def`, so it runs on the event-loop thread and
        # other requests wait until it returns. Consider a plain `def` route
        # or starlette's run_in_threadpool if concurrent requests matter.
        audio_arr, sample_rate = generate_audio(item.text)

        # Render the waveform to an in-memory WAV file. The `with` block
        # releases the buffer once its bytes have been copied out.
        with io.BytesIO() as wav_buffer:
            sf.write(wav_buffer, audio_arr, sample_rate, format="WAV")
            wav_bytes = wav_buffer.getvalue()

        processed_items.append({
            "text": item.text,
            "name": item.name,
            "section": item.section,
            "processed": True,
            # JSON cannot carry binary data; base64 makes it a plain string.
            "audio": base64.b64encode(wav_bytes).decode("ascii"),
            "sample_rate": sample_rate,
        })

    return {"processed_items": processed_items}

if __name__ == "__main__":
    # Script entry point: serve the app on localhost:8000. uvicorn is
    # imported here, so merely importing this module does not import it.
    from uvicorn import run as serve

    serve(app, host="127.0.0.1", port=8000)