Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
import io | |
from pydub import AudioSegment | |
import tempfile | |
import requests | |
import time | |
from dataclasses import dataclass, field | |
from threading import Lock | |
import base64 | |
import uuid | |
import os | |
import json | |
import sseclient | |
class AppState: | |
stream: np.ndarray | None = None | |
sampling_rate: int = 0 | |
conversation: list = field(default_factory=list) | |
api_key: str = os.getenv("API_KEY", "") | |
output_format: str = "mp3" | |
url: str = "https://audio.herm.studio/v1/chat/completions" | |
# Module-wide mutex intended to guard shared state across threads.
# NOTE(review): nothing in the visible code acquires it — confirm it is still needed.
state_lock = Lock()
def process_audio(audio: tuple, state: "AppState"):
    """Append one streamed microphone chunk to the session's audio buffer.

    Args:
        audio: ``(sampling_rate, samples)`` pair; element 0 is stored as the
            sampling rate and element 1 is appended to the buffer. ``None``
            is tolerated and leaves the state unchanged.
        state: Session state whose ``stream`` / ``sampling_rate`` are updated
            in place.

    Returns:
        The same ``state`` object, so it can be fed back as an output.
    """
    if audio is None:
        # Nothing was captured for this tick; indexing None would raise.
        return state

    sampling_rate, samples = audio[0], audio[1]
    if state.stream is None:
        # First chunk of a recording: remember its rate alongside the data.
        state.stream = samples
        state.sampling_rate = sampling_rate
    else:
        state.stream = np.concatenate((state.stream, samples))
    return state
def update_or_append_conversation(conversation, id, role, new_content):
    """Upsert a message in ``conversation``, keyed by ``(id, role)``.

    The first entry whose "id" and "role" both match has its "content"
    replaced in place; when no entry matches, a new
    ``{"id", "role", "content"}`` dict is appended. Returns ``None``.
    """
    existing = next(
        (msg for msg in conversation if msg["id"] == id and msg["role"] == role),
        None,
    )
    if existing is None:
        conversation.append({"id": id, "role": role, "content": new_content})
    else:
        existing["content"] = new_content
def generate_response_and_audio(audio_bytes: bytes, state: "AppState"):
    """Stream a chat completion for one recorded utterance.

    Posts the prior conversation plus the new audio (base64-encoded) to
    ``state.url`` and consumes the reply as server-sent events.

    Args:
        audio_bytes: Encoded audio of the user's utterance.
        state: Session state supplying ``api_key``, ``url`` and
            ``conversation``; it is passed through unchanged in every tuple.

    Yields:
        ``(message_id, text, asr_text, audio, state)`` tuples. ``text`` is the
        reply accumulated so far and ``asr_text`` the latest transcription.
        ``audio`` is ``None`` on every incremental update; one final tuple
        carries the decoded audio bytes when the reply contained audio.

    Raises:
        gr.Error: If no API key is set, the request fails, the stream cannot
            be processed, or the API produced no usable output.
    """
    if not state.api_key:
        raise gr.Error("Please enter a valid API key first.")

    headers = {
        "X-API-Key": state.api_key,
        "Content-Type": "application/json",
    }
    audio_data = base64.b64encode(audio_bytes).decode()
    messages = [
        {"role": item["role"], "content": item["content"]}
        for item in state.conversation
    ]
    messages.append(
        {"role": "user", "content": [{"type": "audio", "data": audio_data}]}
    )
    payload = {"messages": messages, "stream": True, "max_tokens": 256}

    try:
        # The context manager releases the connection even when the loop
        # stops early on "[DONE]" or the consumer abandons the generator.
        # timeout is (connect, read-between-bytes) in seconds so a stalled
        # server cannot block the worker forever.
        with requests.post(
            state.url,
            headers=headers,
            json=payload,
            stream=True,
            timeout=(10, 120),
        ) as response:
            response.raise_for_status()
            # raise_for_status only rejects 4xx/5xx; other non-200 codes
            # (e.g. 204) are not a usable event stream either.
            if response.status_code != 200:
                raise gr.Error(f"API returned status code {response.status_code}")

            full_response = ""
            asr_result = ""
            audio_chunks = []
            message_id = uuid.uuid4()

            for event in sseclient.SSEClient(response).events():
                if event.data == "[DONE]":
                    break
                try:
                    chunk = json.loads(event.data)
                except json.JSONDecodeError:
                    # Skip events whose payload is not JSON.
                    continue
                if "choices" not in chunk or not chunk["choices"]:
                    continue
                choice = chunk["choices"][0]

                if "delta" in choice and "content" in choice["delta"]:
                    content = choice["delta"].get("content")
                    if content is not None:
                        full_response += content
                        yield message_id, full_response, asr_result, None, state
                if "asr_results" in choice:
                    asr_result = "".join(choice["asr_results"])
                    yield message_id, full_response, asr_result, None, state
                if "audio" in choice and choice["audio"] is not None:
                    audio_chunks.extend(choice["audio"])

            if audio_chunks:
                try:
                    final_audio = b"".join(
                        base64.b64decode(part) for part in audio_chunks
                    )
                except TypeError:
                    # Best effort: audio parts of an unexpected type are
                    # dropped rather than failing the text reply.
                    final_audio = None
                if final_audio is not None:
                    yield message_id, full_response, asr_result, final_audio, state

            if not full_response and not asr_result and not audio_chunks:
                raise gr.Error("No valid response received from the API")
    except gr.Error:
        # Already a user-facing error; re-wrapping it below would prepend a
        # misleading "Error during audio streaming:" prefix.
        raise
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Request failed: {e}") from e
    except Exception as e:
        raise gr.Error(f"Error during audio streaming: {e}") from e
def response(state: "AppState"):
    """Send the buffered recording to the API and stream back the reply.

    The accumulated samples in ``state.stream`` are wrapped as WAV, passed to
    ``generate_response_and_audio``, and each update from it is folded into
    ``state.conversation``.

    Args:
        state: Session state holding the recorded ``stream``, its
            ``sampling_rate`` and the ``conversation`` list.

    Yields:
        ``(conversation, audio, state)`` for every update from the API;
        ``audio`` is whatever the underlying generator produced (``None``
        for text-only updates). Nothing is yielded when no audio was
        recorded.
    """
    if state.stream is None or len(state.stream) == 0:
        # This is a generator, so a value given to ``return`` would be
        # discarded; a bare return makes "no updates" explicit.
        return

    try:
        audio_buffer = io.BytesIO()
        segment = AudioSegment(
            state.stream.tobytes(),
            frame_rate=state.sampling_rate,
            sample_width=state.stream.dtype.itemsize,
            # 1-D buffers are mono; otherwise the second axis is used as the
            # channel count.
            channels=(1 if len(state.stream.shape) == 1 else state.stream.shape[1]),
        )
        segment.export(audio_buffer, format="wav")

        updates = generate_response_and_audio(audio_buffer.getvalue(), state)
        for message_id, text, asr, audio, updated_state in updates:
            state = updated_state
            if asr:
                update_or_append_conversation(
                    state.conversation, message_id, "user", asr
                )
            if text:
                update_or_append_conversation(
                    state.conversation, message_id, "assistant", text
                )
            yield state.conversation, audio, state
    finally:
        # Always drop the consumed recording, including when the request
        # fails; otherwise process_audio would concatenate the next
        # recording onto this one.
        state.stream = None
def set_api_key(api_key, state):
    """Store the entered API key on the session state and hide the key form.

    Returns a 4-tuple: an update that shows the status box with a success
    message, updates hiding the key input and the button, and the state.
    """
    state.api_key = api_key
    return (
        gr.update(value="API key set successfully!", visible=True),  # status box
        gr.update(visible=False),  # key input
        gr.update(visible=False),  # "set key" button
        state,
    )
def initial_setup(state):
    """Choose the initial visibility of the API-key widgets.

    When the state already carries a key, the status box is shown with a
    notice and the key form is hidden; otherwise the status box is hidden and
    the form is shown. Returns ``(status, input, button, state)`` updates.
    """
    needs_key = not state.api_key
    if needs_key:
        status_update = gr.update(visible=False)
    else:
        status_update = gr.update(value="Using default API key", visible=True)
    input_update = gr.update(visible=needs_key)
    button_update = gr.update(visible=needs_key)
    return status_update, input_update, button_update, state
# Build the Gradio UI and wire its events.
# NOTE(review): the original indentation was lost, so the nesting below is
# reconstructed from statement order — confirm the layout against the running app.
with gr.Blocks() as demo:
    gr.Markdown("# LLM Voice Mode")
    # API-key entry: password textbox next to a submit button.
    with gr.Row():
        with gr.Column(scale=3):
            api_key_input = gr.Textbox(type="password", placeholder="Enter your API Key", show_label=False, container=False)
        with gr.Column(scale=1):
            set_key_button = gr.Button("Set API Key", scale=2, variant="primary")
    # Status line, hidden until initial_setup or set_api_key reveals it.
    api_key_status = gr.Textbox(show_label=False, container=False, interactive=False, visible=False)
    with gr.Blocks():
        with gr.Row():
            # Microphone input delivered as numpy data; output plays automatically.
            input_audio = gr.Audio(label="Input Audio", sources="microphone", type="numpy")
            output_audio = gr.Audio(label="Output Audio", autoplay=True, streaming=True)
        chatbot = gr.Chatbot(label="Conversation", type="messages")
    # Session state seeded from a fresh AppState.
    state = gr.State(AppState())
    # On page load, show or hide the key form depending on whether a key is already set.
    demo.load(initial_setup, inputs=state, outputs=[api_key_status, api_key_input, set_key_button, state])
    set_key_button.click(set_api_key, inputs=[api_key_input, state], outputs=[api_key_status, api_key_input, set_key_button, state])
    # While recording, process_audio accumulates each streamed chunk into the state.
    # NOTE(review): stream_every / time_limit are presumably seconds — confirm in the Gradio docs.
    stream = input_audio.stream(process_audio, [input_audio, state], [state], stream_every=0.25, time_limit=60)
    # When recording stops, response sends the buffer to the API and streams the reply.
    respond = input_audio.stop_recording(response, [state], [chatbot, output_audio, state])
    # After the reply finishes, refresh the chatbot from the stored conversation.
    respond.then(lambda s: s.conversation, [state], [chatbot])
demo.launch()