Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import io
|
4 |
+
from pydub import AudioSegment
|
5 |
+
import tempfile
|
6 |
+
import requests
|
7 |
+
import time
|
8 |
+
from dataclasses import dataclass, field
|
9 |
+
from threading import Lock
|
10 |
+
import base64
|
11 |
+
import uuid
|
12 |
+
import os
|
13 |
+
import json
|
14 |
+
import sseclient
|
15 |
+
|
16 |
+
@dataclass
|
17 |
+
class AppState:
|
18 |
+
stream: np.ndarray | None = None
|
19 |
+
sampling_rate: int = 0
|
20 |
+
conversation: list = field(default_factory=list)
|
21 |
+
api_key: str = os.getenv("API_KEY", "")
|
22 |
+
output_format: str = "mp3"
|
23 |
+
url: str = "https://audio.herm.studio/v1/chat/completions"
|
24 |
+
|
25 |
+
# Global lock for thread safety.
# NOTE(review): this lock is never acquired anywhere in the visible code, so
# it currently protects nothing — confirm whether it is still needed.
state_lock = Lock()
|
27 |
+
|
28 |
+
def process_audio(audio: tuple, state: AppState):
    """Append one streamed microphone chunk to the recording held in *state*.

    Args:
        audio: ``(sampling_rate, samples)`` pair for the new chunk, or None
            when no chunk is available.
        state: Session state whose ``stream`` buffer is extended. The
            sampling rate is stored only together with the first chunk.

    Returns:
        The same ``state`` object, updated in place.
    """
    if audio is None:
        # Without this guard a missing chunk fails on ``audio[1]`` below.
        return state

    sampling_rate, samples = audio[0], audio[1]
    if state.stream is None:
        state.stream = samples
        state.sampling_rate = sampling_rate
    else:
        state.stream = np.concatenate((state.stream, samples))
    return state
|
35 |
+
|
36 |
+
def update_or_append_conversation(conversation, id, role, new_content):
    """Overwrite the content of the entry matching (id, role), else append one.

    Args:
        conversation: List of entry dicts, modified in place.
        id: Identifier compared against each entry's ``"id"``.
        role: Role compared against each entry's ``"role"``.
        new_content: Content stored on the matched or newly appended entry.
    """
    # Only the first matching entry is considered, mirroring a linear scan
    # that stops at the first hit.
    existing = next(
        (
            item
            for item in conversation
            if item["id"] == id and item["role"] == role
        ),
        None,
    )
    if existing is None:
        conversation.append({"id": id, "role": role, "content": new_content})
    else:
        existing["content"] = new_content
|
42 |
+
|
43 |
+
def generate_response_and_audio(audio_bytes: bytes, state: AppState, *, timeout=(10, 120)):
    """Send the recorded audio to the chat API and stream back the reply.

    Generator. Posts the conversation so far plus the new audio (base64
    encoded) to ``state.url`` and consumes the server-sent-event stream.

    Args:
        audio_bytes: Encoded audio for the new user turn.
        state: Supplies ``api_key``, ``url`` and ``conversation``; it is
            yielded back as-is.
        timeout: ``(connect, read)`` timeout in seconds handed to
            ``requests.post`` so a stalled server cannot block forever.

    Yields:
        ``(message_id, text, asr_text, audio, state)`` tuples. ``text`` and
        ``asr_text`` are the values accumulated so far; ``audio`` is None
        except in the last tuple, which carries all decoded audio bytes.

    Raises:
        gr.Error: If no API key is set, the request fails, the API answers
            with a non-200 status, or the stream yields nothing usable.
    """
    if not state.api_key:
        raise gr.Error("Please enter a valid API key first.")

    headers = {
        "X-API-Key": state.api_key,
        "Content-Type": "application/json",
    }

    audio_data = base64.b64encode(audio_bytes).decode()
    # Only role/content are sent; the local "id" bookkeeping key is dropped.
    messages = [
        {"role": item["role"], "content": item["content"]}
        for item in state.conversation
    ]
    messages.append(
        {"role": "user", "content": [{"type": "audio", "data": audio_data}]}
    )

    data = {
        "messages": messages,
        "stream": True,
        "max_tokens": 256,
    }

    try:
        # The context manager closes the streamed connection on every exit
        # path, including when the consumer stops iterating early.
        with requests.post(
            state.url, headers=headers, json=data, stream=True, timeout=timeout
        ) as response:
            response.raise_for_status()

            if response.status_code != 200:
                raise gr.Error(f"API returned status code {response.status_code}")

            client = sseclient.SSEClient(response)

            full_response = ""
            asr_result = ""
            audio_chunks = []
            message_id = uuid.uuid4()

            for event in client.events():
                if event.data == "[DONE]":
                    break

                try:
                    chunk = json.loads(event.data)
                except json.JSONDecodeError:
                    # Skip events whose payload is not JSON.
                    continue

                if 'choices' not in chunk or not chunk['choices']:
                    continue

                choice = chunk['choices'][0]

                if 'delta' in choice and 'content' in choice['delta']:
                    content = choice['delta'].get('content')
                    if content is not None:
                        full_response += content
                        yield message_id, full_response, asr_result, None, state

                if 'asr_results' in choice:
                    asr_result = "".join(choice['asr_results'])
                    yield message_id, full_response, asr_result, None, state

                if 'audio' in choice and choice['audio'] is not None:
                    audio_part = choice['audio']
                    if isinstance(audio_part, str):
                        # extend() would split a bare string into single
                        # characters, which cannot be base64-decoded.
                        audio_chunks.append(audio_part)
                    else:
                        audio_chunks.extend(audio_part)

            if audio_chunks:
                try:
                    final_audio = b"".join(
                        base64.b64decode(part) for part in audio_chunks
                    )
                except TypeError:
                    # Best effort: chunks of an unexpected type drop the
                    # audio but keep the text that was already yielded.
                    pass
                else:
                    yield message_id, full_response, asr_result, final_audio, state

            if not full_response and not asr_result and not audio_chunks:
                raise gr.Error("No valid response received from the API")

    except gr.Error:
        # Already user-facing; do not re-wrap it in the generic handler below.
        raise
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Request failed: {e}") from e
    except Exception as e:
        raise gr.Error(f"Error during audio streaming: {e}") from e
|
118 |
+
|
119 |
+
def response(state: AppState):
    """Turn the buffered recording into a chat request and stream the reply.

    Generator. Encodes ``state.stream`` as WAV, passes it to
    ``generate_response_and_audio`` and records the transcription and the
    assistant text in ``state.conversation`` as they arrive.

    Args:
        state: Session state holding the buffered recording and conversation.

    Yields:
        ``(conversation, audio, state)`` after every update from the API.
        Nothing is yielded when no audio has been buffered.
    """
    if state.stream is None or len(state.stream) == 0:
        # A value returned from a generator is discarded by iteration, so the
        # former ``return None, None, state`` never reached any output.
        return

    recording = state.stream
    # Clear the buffer right away: resetting only after a successful stream
    # left stale audio behind whenever the request failed, and a late reset
    # could wipe a recording started while the reply was still streaming.
    state.stream = None

    audio_buffer = io.BytesIO()
    segment = AudioSegment(
        recording.tobytes(),
        frame_rate=state.sampling_rate,
        sample_width=recording.dtype.itemsize,
        channels=(1 if len(recording.shape) == 1 else recording.shape[1]),
    )
    segment.export(audio_buffer, format="wav")

    generator = generate_response_and_audio(audio_buffer.getvalue(), state)

    for message_id, text, asr, audio, updated_state in generator:
        state = updated_state
        if asr:
            update_or_append_conversation(state.conversation, message_id, "user", asr)
        if text:
            update_or_append_conversation(state.conversation, message_id, "assistant", text)
        yield state.conversation, audio, state
|
144 |
+
|
145 |
+
def set_api_key(api_key, state):
    """Store the API key entered by the user and update the key widgets.

    Args:
        api_key: Text taken from the API-key textbox.
        state: Session state whose ``api_key`` is set on success.

    Returns:
        ``(api_key_status, api_key_input, set_key_button, state)``: three
        ``gr.update`` objects followed by the state.
    """
    cleaned_key = (api_key or "").strip()
    if not cleaned_key:
        # Reporting success for an empty key and hiding the entry widgets
        # would leave the user with no way to supply a real key.
        return (
            gr.update(value="Please enter a non-empty API key.", visible=True),
            gr.update(visible=True),
            gr.update(visible=True),
            state,
        )

    state.api_key = cleaned_key
    api_key_status = gr.update(value="API key set successfully!", visible=True)
    api_key_input = gr.update(visible=False)
    set_key_button = gr.update(visible=False)
    return api_key_status, api_key_input, set_key_button, state
|
151 |
+
|
152 |
+
def initial_setup(state):
    """Choose the initial visibility of the API-key widgets.

    With a key already present in ``state`` the status line is shown and the
    entry widgets are hidden; without one it is the other way round.

    Returns:
        ``(api_key_status, api_key_input, set_key_button, state)``.
    """
    if not state.api_key:
        # No key yet: hide the status line and show the entry widgets.
        return (
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=True),
            state,
        )

    status_update = gr.update(value="Using default API key", visible=True)
    input_update = gr.update(visible=False)
    button_update = gr.update(visible=False)
    return status_update, input_update, button_update, state
|
162 |
+
|
163 |
+
# Build the UI and wire the event handlers.
# NOTE(review): the extracted source carried no indentation, so the layout
# nesting below is reconstructed from statement order — confirm against the
# original file.
with gr.Blocks() as demo:
    gr.Markdown("# LLM Voice Mode")
    with gr.Row():
        with gr.Column(scale=3):
            api_key_input = gr.Textbox(type="password", placeholder="Enter your API Key", show_label=False, container=False)
        with gr.Column(scale=1):
            set_key_button = gr.Button("Set API Key", scale=2, variant="primary")

    # Status line; created hidden and made visible by the key handlers.
    api_key_status = gr.Textbox(show_label=False, container=False, interactive=False, visible=False)

    with gr.Blocks():
        with gr.Row():
            input_audio = gr.Audio(label="Input Audio", sources="microphone", type="numpy")
            output_audio = gr.Audio(label="Output Audio", autoplay=True, streaming=True)
        chatbot = gr.Chatbot(label="Conversation", type="messages")

    state = gr.State(AppState())

    # On page load, show or hide the key widgets depending on state.api_key.
    demo.load(initial_setup, inputs=state, outputs=[api_key_status, api_key_input, set_key_button, state])

    set_key_button.click(set_api_key, inputs=[api_key_input, state], outputs=[api_key_status, api_key_input, set_key_button, state])

    # Buffer microphone chunks into the state while recording.
    stream = input_audio.stream(process_audio, [input_audio, state], [state], stream_every=0.25, time_limit=60)

    # When recording stops, send the buffered audio and stream the reply.
    respond = input_audio.stop_recording(response, [state], [chatbot, output_audio, state])
    # Afterwards, refresh the chatbot from the stored conversation.
    respond.then(lambda s: s.conversation, [state], [chatbot])

demo.launch()
|