Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
import io | |
from pydub import AudioSegment | |
import requests | |
from dataclasses import dataclass, field | |
from threading import Lock | |
import base64 | |
import uuid | |
import json | |
import sseclient | |
import os | |
class AppState: | |
stream: np.ndarray | None = None | |
sampling_rate: int = 0 | |
conversation: list = field(default_factory=list) | |
api_key: str = os.getenv("API_KEY", "") | |
output_format: str = "mp3" | |
url: str = "https://audio.herm.studio/v1/chat/completions" | |
# Module-level lock created at import time. No code in this module acquires it;
# presumably reserved for guarding shared state — TODO confirm or remove.
state_lock = Lock()
def process_audio(audio: tuple, state: AppState):
    """Append one streamed microphone chunk to the state's sample buffer.

    Args:
        audio: ``(sample_rate, samples)`` pair delivered by the ``gr.Audio``
            stream event (the component uses ``type="numpy"``).
        state: Session state whose ``stream`` buffer is extended in place.

    Returns:
        The same ``state`` object, so Gradio stores the updated buffer.

    The sample rate is recorded from the first chunk of an utterance only;
    later chunks are assumed to share it.
    """
    samples = audio[1]
    if state.stream is not None:
        # Continue the current utterance: join along the first (time) axis.
        state.stream = np.concatenate((state.stream, samples))
        return state
    # First chunk since the buffer was last cleared.
    state.stream = samples
    state.sampling_rate = audio[0]
    return state
def update_or_append_conversation(conversation, id, role, new_content):
    """Set the content of the entry matching ``(id, role)``, or append a new one.

    Only the first matching entry is updated. ``conversation`` is modified
    in place and nothing is returned.
    """
    existing = next(
        (msg for msg in conversation if msg["id"] == id and msg["role"] == role),
        None,
    )
    if existing is None:
        conversation.append({"id": id, "role": role, "content": new_content})
    else:
        existing["content"] = new_content
def generate_response_and_audio(
    audio_bytes: bytes,
    state: AppState,
    timeout: float | tuple[float, float] = (10, 120),
):
    """Send recorded audio to the chat-completions API and stream back the reply.

    The request body contains the prior conversation (role/content only)
    followed by a new user turn carrying ``audio_bytes`` base64-encoded, and
    asks for a streamed (server-sent events) response.

    Args:
        audio_bytes: Encoded audio for the new user turn.
        state: Session state supplying ``api_key``, ``url`` and ``conversation``.
        timeout: ``requests`` timeout as ``(connect, read)`` seconds. With
            ``stream=True`` the read timeout bounds the wait between received
            bytes, not the total length of the stream.

    Yields:
        ``(message_id, text, asr_text, audio, state)`` tuples. ``text`` is the
        assistant reply accumulated so far and ``asr_text`` the latest
        transcription of the user's audio. ``audio`` is ``None`` except in one
        final yield carrying all decoded audio chunks concatenated.

    Raises:
        gr.Error: if no API key is set, the request fails, the response holds
            no text, transcription or audio, or any other error occurs while
            streaming.
    """
    if not state.api_key:
        raise gr.Error("Please enter a valid API key first.")

    headers = {
        "X-API-Key": state.api_key,
        "Content-Type": "application/json",
    }
    audio_data = base64.b64encode(audio_bytes).decode()
    # Replay earlier turns with only role/content (the local "id" key is dropped),
    # then append the new spoken user turn.
    messages = [{"role": item["role"], "content": item["content"]} for item in state.conversation]
    messages.append({"role": "user", "content": [{"type": "audio", "data": audio_data}]})
    data = {
        "messages": messages,
        "stream": True,
        "max_tokens": 256,
    }

    try:
        # The context manager releases the streaming connection even when the
        # consumer stops iterating early or an exception escapes.
        with requests.post(
            state.url, headers=headers, json=data, stream=True, timeout=timeout
        ) as response:
            response.raise_for_status()
            if response.status_code != 200:
                raise gr.Error(f"API returned status code {response.status_code}")

            client = sseclient.SSEClient(response)
            full_response = ""
            asr_result = ""
            audio_chunks = []
            # One id per turn, shared by the user (ASR) and assistant entries.
            message_id = uuid.uuid4()

            for event in client.events():
                if event.data == "[DONE]":
                    break
                try:
                    chunk = json.loads(event.data)
                except json.JSONDecodeError:
                    continue  # ignore events whose payload is not JSON
                if "choices" not in chunk or not chunk["choices"]:
                    continue
                choice = chunk["choices"][0]

                if "delta" in choice and "content" in choice["delta"]:
                    content = choice["delta"].get("content")
                    if content is not None:
                        full_response += content
                        yield message_id, full_response, asr_result, None, state
                if "asr_results" in choice:
                    asr_result = "".join(choice["asr_results"])
                    yield message_id, full_response, asr_result, None, state
                if "audio" in choice and choice["audio"] is not None:
                    # Kept base64-encoded; decoded once after the stream ends.
                    audio_chunks.extend(choice["audio"])

            if audio_chunks:
                try:
                    final_audio = b"".join(base64.b64decode(a) for a in audio_chunks)
                except TypeError:
                    # Best effort: non-string audio entries must not hide the text reply.
                    pass
                else:
                    yield message_id, full_response, asr_result, final_audio, state

            if not full_response and not asr_result and not audio_chunks:
                raise gr.Error("No valid response received from the API")
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Request failed: {str(e)}") from e
    except gr.Error:
        # Already a user-facing message; don't wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f"Error during audio streaming: {str(e)}") from e
def response(state: AppState):
    """Send the buffered utterance to the API and stream the reply into the UI.

    Generator registered as the ``stop_recording`` handler. Each yield is
    ``(chat_messages, audio, state)``, matching the outputs
    ``[chatbot, output_audio, state]``.

    The buffered samples are wrapped in a WAV container with pydub, the
    buffer is cleared, and every update from ``generate_response_and_audio``
    is merged into ``state.conversation``. The user (ASR) and assistant
    entries of one turn share a single message id.
    """
    if state.stream is None or len(state.stream) == 0:
        # Nothing was recorded. Because this is a generator, returning emits
        # no outputs, so the UI is left unchanged.
        return

    samples = state.stream
    # Clear the buffer before any work that can fail (export, API call). Otherwise
    # a failed request would leave these samples buffered, the next recording
    # would be appended to them, and process_audio would keep the stale
    # sampling_rate (it only sets it when stream is None).
    state.stream = None

    channels = 1 if samples.ndim == 1 else samples.shape[1]
    # NOTE(review): treats samples as integer PCM of width dtype.itemsize; a float
    # array would be misinterpreted — confirm gr.Audio delivers integer samples.
    segment = AudioSegment(
        samples.tobytes(),
        frame_rate=state.sampling_rate,
        sample_width=samples.dtype.itemsize,
        channels=channels,
    )
    audio_buffer = io.BytesIO()
    segment.export(audio_buffer, format="wav")

    for message_id, text, asr, audio, updated_state in generate_response_and_audio(
        audio_buffer.getvalue(), state
    ):
        state = updated_state
        if asr:
            update_or_append_conversation(state.conversation, message_id, "user", asr)
        if text:
            update_or_append_conversation(state.conversation, message_id, "assistant", text)
        yield state.conversation, audio, state
def initial_setup(state):
    """Page-load handler: show the API-key status line, or fail if no key is set.

    Returns:
        A ``gr.update`` that fills in and shows the status textbox.

    Raises:
        gr.Error: when ``state.api_key`` is empty.
    """
    if state.api_key:
        return gr.update(value="The API key used is supported by Herm studio", visible=True)
    raise gr.Error("API key not found in environment variables. Please set the API_KEY environment variable.")
# Build the UI (status line, microphone input, streamed audio output, chat
# history) and wire the events to the handlers defined in this module. Event
# listeners are registered inside the `demo` Blocks context.
with gr.Blocks() as demo:
    gr.Markdown("# LLM Voice Mode")
    # Read-only status line; its value is set by initial_setup on page load.
    api_key_status = gr.Textbox(
        show_label=False,
        container=False,
        interactive=False,
        visible=True
    )
    # NOTE(review): a nested gr.Blocks is used here as a grouping container; a
    # layout block such as gr.Column may have been intended — confirm.
    with gr.Blocks():
        with gr.Row():
            # Microphone input delivered as numpy data to the stream handler.
            input_audio = gr.Audio(
                label="Input Audio",
                sources="microphone",
                type="numpy"
            )
            # Reply audio, auto-played and configured for streaming output.
            output_audio = gr.Audio(
                label="Output Audio",
                autoplay=True,
                streaming=True
            )
        # Chat history rendered from role/content message dicts.
        chatbot = gr.Chatbot(
            label="Conversation",
            type="messages"
        )
        # AppState instance passed to and returned by the handlers below.
        state = gr.State(AppState())
    # On page load: show the status text, or an error if no API key is configured.
    demo.load(
        fn=initial_setup,
        inputs=state,
        outputs=api_key_status
    )
    # While recording, send a microphone chunk every 0.25 s to be appended to the buffer.
    # NOTE(review): time_limit=60 presumably caps how long one stream may run — confirm.
    input_audio.stream(
        fn=process_audio,
        inputs=[input_audio, state],
        outputs=[state],
        stream_every=0.25,
        time_limit=60
    )
    # When recording stops, send the buffered audio and stream the reply into
    # the chatbot, the output audio and the state.
    respond = input_audio.stop_recording(
        fn=response,
        inputs=[state],
        outputs=[chatbot, output_audio, state]
    )
    # Afterwards, re-render the chatbot from the final conversation list.
    respond.then(
        fn=lambda s: s.conversation,
        inputs=[state],
        outputs=[chatbot]
    )
# Start the server only when this file is run as a script (e.g. `python app.py`),
# so importing the module builds the UI without launching it.
if __name__ == "__main__":
    demo.launch()