Lokis committed on
Commit
1ef6ffb
1 Parent(s): 88ebded

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -0
app.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import io
4
+ from pydub import AudioSegment
5
+ import tempfile
6
+ import requests
7
+ import time
8
+ from dataclasses import dataclass, field
9
+ from threading import Lock
10
+ import base64
11
+ import uuid
12
+ import os
13
+ import json
14
+ import sseclient
15
+
16
+ @dataclass
17
+ class AppState:
18
+ stream: np.ndarray | None = None
19
+ sampling_rate: int = 0
20
+ conversation: list = field(default_factory=list)
21
+ api_key: str = os.getenv("API_KEY", "")
22
+ output_format: str = "mp3"
23
+ url: str = "https://audio.herm.studio/v1/chat/completions"
24
+
25
# Module-wide mutex meant to guard shared state across threads.
# NOTE(review): none of the handlers visible in this file acquire it.
state_lock = Lock()
27
+
28
def process_audio(audio: tuple, state: AppState):
    """Append one streamed microphone chunk to the session's audio buffer.

    Args:
        audio: ``(sampling_rate, samples)`` pair from the audio input
            component, or ``None`` when the component has nothing to
            deliver for this tick.
        state: Session state whose ``stream`` buffer is extended.

    Returns:
        The same ``state`` object, so it can be written back to ``gr.State``.
    """
    # Without this guard a ``None`` payload would fail on ``audio[1]`` with
    # a TypeError; an empty tick simply leaves the buffer untouched.
    if audio is None:
        return state

    sampling_rate, samples = audio[0], audio[1]
    if state.stream is None:
        # First chunk of a recording: it also fixes the sampling rate.
        state.stream = samples
        state.sampling_rate = sampling_rate
    else:
        state.stream = np.concatenate((state.stream, samples))
    return state
35
+
36
def update_or_append_conversation(conversation, id, role, new_content):
    """Overwrite the content of the entry matching ``id`` and ``role``, or add one.

    Args:
        conversation: List of message dicts, mutated in place.
        id: Identifier shared by the messages of one exchange.
        role: Role of the message to update or create.
        new_content: Content to store on the matching (or new) entry.

    Returns:
        None. Only the first matching entry is updated.
    """
    existing = next(
        (
            message
            for message in conversation
            if message["id"] == id and message["role"] == role
        ),
        None,
    )
    if existing is None:
        conversation.append({"id": id, "role": role, "content": new_content})
    else:
        existing["content"] = new_content
42
+
43
def generate_response_and_audio(
    audio_bytes: bytes,
    state: AppState,
    *,
    timeout: float | tuple[float, float] = (10.0, 120.0),
):
    """Stream the API's reply to one recorded utterance.

    The chat history from ``state.conversation`` plus the base64-encoded
    recording is POSTed to ``state.url`` and the server-sent events of the
    reply are consumed as they arrive.

    Args:
        audio_bytes: Encoded audio of the user's utterance.
        state: Session state supplying the API key, URL and chat history.
        timeout: Passed to ``requests.post``; a ``(connect, read)`` pair or a
            single number of seconds. Prevents a stalled server from
            blocking the handler indefinitely.

    Yields:
        ``(message_id, text, asr_text, audio, state)`` tuples. ``text`` and
        ``asr_text`` are the values accumulated so far; ``audio`` is ``None``
        until the stream has ended, after which one final tuple carries the
        decoded audio bytes (if any audio was received).

    Raises:
        gr.Error: If no API key is set, the request fails, the stream cannot
            be processed, or the reply contained no text, transcript or audio.
    """
    if not state.api_key:
        raise gr.Error("Please enter a valid API key first.")

    headers = {
        "X-API-Key": state.api_key,
        "Content-Type": "application/json",
    }

    audio_data = base64.b64encode(audio_bytes).decode()
    # Only role/content are sent; the local "id" bookkeeping key is dropped.
    messages = [
        {"role": item["role"], "content": item["content"]}
        for item in state.conversation
    ]
    messages.append(
        {"role": "user", "content": [{"type": "audio", "data": audio_data}]}
    )
    payload = {
        "messages": messages,
        "stream": True,
        "max_tokens": 256,
    }

    try:
        full_response = ""
        asr_result = ""
        audio_chunks = []
        message_id = uuid.uuid4()

        # The context manager releases the connection even when the loop is
        # left early ("[DONE]", an error, or the consumer abandoning us).
        with requests.post(
            state.url,
            headers=headers,
            json=payload,
            stream=True,
            timeout=timeout,
        ) as resp:
            resp.raise_for_status()

            # raise_for_status() lets other 2xx codes through; only 200 is
            # treated as a usable stream.
            if resp.status_code != 200:
                raise gr.Error(f"API returned status code {resp.status_code}")

            for event in sseclient.SSEClient(resp).events():
                if event.data == "[DONE]":
                    break

                try:
                    chunk = json.loads(event.data)
                except json.JSONDecodeError:
                    # Skip events that are not JSON (best effort).
                    continue

                if 'choices' not in chunk or not chunk['choices']:
                    continue

                choice = chunk['choices'][0]

                if 'delta' in choice and 'content' in choice['delta']:
                    content = choice['delta'].get('content')
                    if content is not None:
                        full_response += content
                        yield message_id, full_response, asr_result, None, state

                if 'asr_results' in choice:
                    asr_result = "".join(choice['asr_results'])
                    yield message_id, full_response, asr_result, None, state

                audio_part = choice.get('audio')
                if audio_part is not None:
                    audio_chunks.extend(audio_part)

        if audio_chunks:
            try:
                final_audio = b"".join(base64.b64decode(a) for a in audio_chunks)
            except TypeError:
                # Best effort, as before: chunks of an unexpected type are
                # dropped rather than failing a reply that may carry text.
                final_audio = None
            if final_audio is not None:
                yield message_id, full_response, asr_result, final_audio, state

        if not full_response and not asr_result and not audio_chunks:
            raise gr.Error("No valid response received from the API")

    except gr.Error:
        # Already user-facing; do not re-wrap it in the generic handler below.
        raise
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Request failed: {str(e)}") from e
    except Exception as e:
        raise gr.Error(f"Error during audio streaming: {str(e)}") from e
118
+
119
def response(state: AppState):
    """Send the buffered recording to the API and stream the reply to the UI.

    The buffered samples are encoded as WAV, passed to
    ``generate_response_and_audio`` and every partial result is folded into
    ``state.conversation``.

    Args:
        state: Session state holding the recorded audio and chat history.

    Yields:
        ``(conversation, audio, state)`` tuples for the chatbot, the output
        audio component and the session state respectively.
    """
    if state.stream is None or len(state.stream) == 0:
        # This is a generator: a value given to ``return`` would only end up
        # on StopIteration and never be yielded, so finish without output.
        return

    try:
        audio_buffer = io.BytesIO()
        segment = AudioSegment(
            state.stream.tobytes(),
            frame_rate=state.sampling_rate,
            # NOTE(review): assumes integer PCM samples whose item size is
            # the sample width - confirm against the audio component's dtype.
            sample_width=state.stream.dtype.itemsize,
            channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
        )
        segment.export(audio_buffer, format="wav")

        replies = generate_response_and_audio(audio_buffer.getvalue(), state)
        for message_id, text, asr, audio, updated_state in replies:
            state = updated_state
            if asr:
                update_or_append_conversation(state.conversation, message_id, "user", asr)
            if text:
                update_or_append_conversation(state.conversation, message_id, "assistant", text)
            yield state.conversation, audio, state
    finally:
        # Always drop the recording, including when the request fails or the
        # consumer stops early; otherwise the stale audio would be prepended
        # to the next utterance.
        state.stream = None
144
+
145
def set_api_key(api_key, state):
    """Store the user-supplied API key in the session state.

    Args:
        api_key: Text entered in the API key box.
        state: Session state whose ``api_key`` is updated.

    Returns:
        ``(status_update, input_update, button_update, state)``. On success
        the status message is shown and the input and button are hidden. For
        an empty or whitespace-only key nothing is stored and the input stays
        visible so the user can try again.
    """
    cleaned_key = (api_key or "").strip()
    if not cleaned_key:
        # Hiding the controls here would leave no way to enter a real key.
        return (
            gr.update(value="Please enter a non-empty API key.", visible=True),
            gr.update(visible=True),
            gr.update(visible=True),
            state,
        )

    state.api_key = cleaned_key
    api_key_status = gr.update(value="API key set successfully!", visible=True)
    api_key_input = gr.update(visible=False)
    set_key_button = gr.update(visible=False)
    return api_key_status, api_key_input, set_key_button, state
151
+
152
def initial_setup(state):
    """Choose which API-key widgets to show when the page loads.

    With a key already present in ``state`` a status message is shown and
    the key input and button are hidden; otherwise the status is hidden and
    the input and button are shown.

    Returns:
        ``(status_update, input_update, button_update, state)``.
    """
    has_default_key = bool(state.api_key)

    if has_default_key:
        status_update = gr.update(value="Using default API key", visible=True)
    else:
        status_update = gr.update(visible=False)

    # The input box and its button are visible exactly when a key is missing.
    needs_key = not has_default_key
    input_update = gr.update(visible=needs_key)
    button_update = gr.update(visible=needs_key)
    return status_update, input_update, button_update, state
162
+
163
# UI layout and event wiring for the demo.
with gr.Blocks() as demo:
    gr.Markdown("# LLM Voice Mode")

    # API-key entry row: password box plus the button that submits it.
    with gr.Row():
        with gr.Column(scale=3):
            api_key_input = gr.Textbox(
                type="password",
                placeholder="Enter your API Key",
                show_label=False,
                container=False,
            )
        with gr.Column(scale=1):
            set_key_button = gr.Button("Set API Key", scale=2, variant="primary")

    # Read-only status line, hidden until a handler makes it visible.
    api_key_status = gr.Textbox(
        show_label=False,
        container=False,
        interactive=False,
        visible=False,
    )

    with gr.Blocks():
        with gr.Row():
            input_audio = gr.Audio(
                label="Input Audio",
                sources="microphone",
                type="numpy",
            )
            output_audio = gr.Audio(
                label="Output Audio",
                autoplay=True,
                streaming=True,
            )
        chatbot = gr.Chatbot(label="Conversation", type="messages")

    state = gr.State(AppState())

    # Both key-related handlers return the same four outputs.
    _key_outputs = [api_key_status, api_key_input, set_key_button, state]

    demo.load(initial_setup, inputs=state, outputs=_key_outputs)

    set_key_button.click(
        set_api_key,
        inputs=[api_key_input, state],
        outputs=_key_outputs,
    )

    # Buffer microphone chunks while recording.
    stream = input_audio.stream(
        process_audio,
        [input_audio, state],
        [state],
        stream_every=0.25,
        time_limit=60,
    )

    # When recording stops, stream the reply, then refresh the chatbot from
    # the final conversation held in state.
    respond = input_audio.stop_recording(
        response,
        [state],
        [chatbot, output_audio, state],
    )
    respond.then(lambda s: s.conversation, [state], [chatbot])

demo.launch()