Carsten Høyer committed on
Commit
910d316
1 Parent(s): a168b4f

add parler

Browse files
Files changed (2) hide show
  1. app.py +35 -5
  2. requirements.txt +4 -0
app.py CHANGED
@@ -1,6 +1,12 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from typing import List
 
 
 
 
 
 
4
 
5
  # Initialize the FastAPI app
6
  app = FastAPI()
@@ -11,24 +17,48 @@ class Item(BaseModel):
11
  name: str
12
  section: str
13
 
 
 
 
 
 
14
  # A simple GET endpoint
15
  @app.get("/")
16
  def greet_json():
17
  return {"Hello": "World!"}
18
 
 
 
 
 
 
 
 
 
19
  # A POST endpoint to receive and parse an array of JSON objects
20
  @app.post("/")
21
- def create_items(items: List[Item]):
22
- # Process each item in the list
23
  processed_items = []
24
  for item in items:
25
- # Here you could perform any processing you need on each item
 
 
 
 
 
 
 
26
  processed_item = {
27
  "text": item.text,
28
  "name": item.name,
29
  "section": item.section,
30
- "processed": True # Example of adding a field to indicate processing
 
31
  }
32
  processed_items.append(processed_item)
33
-
34
  return {"processed_items": processed_items}
 
 
 
 
 
import base64
import io
from typing import List

import soundfile as sf
import torch
from fastapi import FastAPI
from parler_tts import ParlerTTSForConditionalGeneration
from pydantic import BaseModel
from starlette.responses import StreamingResponse
from transformers import AutoTokenizer
10
 
11
# Initialize the FastAPI app at module level so ASGI servers
# (e.g. `uvicorn app:app`) can import it directly.
app = FastAPI()
 
17
  name: str
18
  section: str
19
 
20
# Initialize ParlerTTS once at import time so every request reuses the
# same model/tokenizer instead of reloading per call.
# NOTE(review): loading here blocks app startup and assumes the checkpoint
# can be fetched/cached — confirm this is acceptable for deployment, or
# move to a lazy/startup-event initializer.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
24
+
25
# Simple GET endpoint, usable as a health/landing check.
@app.get("/")
def greet_json():
    """Return a static JSON greeting."""
    greeting = {"Hello": "World!"}
    return greeting
29
 
30
# Synthesize speech for `text` with the module-level ParlerTTS model.
def generate_audio(text, description="Neutral voice"):
    """Return a `(waveform, sampling_rate)` pair for *text*.

    `description` is a free-text voice description that conditions the
    generated voice (ParlerTTS convention: the description goes through
    `input_ids`, the spoken text through `prompt_input_ids`).
    """
    desc_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    text_ids = tokenizer(text, return_tensors="pt").input_ids.to(device)
    waveform = model.generate(input_ids=desc_ids, prompt_input_ids=text_ids)
    return waveform.cpu().numpy().squeeze(), model.config.sampling_rate
+
38
# POST endpoint: synthesize audio for each submitted item.
@app.post("/")
async def create_items(items: List[Item]):
    """Generate TTS audio for each item and return everything as JSON.

    Bug fix: the previous version embedded a `StreamingResponse` object
    inside the response dict; FastAPI cannot JSON-serialize a response
    object as a value, so every request failed at serialization time.
    The WAV audio is now returned inline as a base64-encoded string.

    Returns:
        {"processed_items": [{"text", "name", "section", "processed",
                              "audio": <base64 WAV>}, ...]}
    """
    processed_items = []
    for item in items:
        # Synthesize the item's text to a waveform.
        audio_arr, sample_rate = generate_audio(item.text)

        # Encode the waveform as WAV bytes in an in-memory buffer.
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, audio_arr, sample_rate, format="WAV")
        audio_buffer.seek(0)  # Reset buffer position before reading

        processed_items.append({
            "text": item.text,
            "name": item.name,
            "section": item.section,
            "processed": True,
            # base64 keeps the audio JSON-serializable for the client.
            "audio": base64.b64encode(audio_buffer.getvalue()).decode("ascii"),
        })

    return {"processed_items": processed_items}
61
+
62
# Allow running the API directly (`python app.py`) as a local dev server;
# production deployments should launch uvicorn/gunicorn against `app` instead.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8000)
requirements.txt CHANGED
@@ -1,2 +1,6 @@
1
  fastapi
2
  uvicorn[standard]
 
 
 
 
 
1
  fastapi
2
  uvicorn[standard]
3
+ torch
4
+ transformers
5
+ parler-tts
6
+ soundfile