Spaces:
Runtime error
Runtime error
import gradio as gr | |
import base64 | |
import requests | |
import secrets | |
import os | |
from io import BytesIO | |
from pydub import AudioSegment | |
# Base URL of the chat backend that every request in this file is sent to.
# It can be overridden with the LOCAL_API_ENDPOINT environment variable so the
# app can run where the backend is not on localhost; the default is unchanged.
LOCAL_API_ENDPOINT = os.environ.get("LOCAL_API_ENDPOINT", "http://localhost:5000")
# Not referenced by the code visible in this file.
# NOTE(review): presumably the public address of the same backend - confirm.
PUBLIC_API_ENDPOINT = "http://121.176.153.117:5000"
def create_chat_session():
    """Create a chat session on the backend and prepare its audio folder.

    Sends a POST to ``/create``, reads the new session id from the ``id``
    field of the JSON reply and makes sure ``./temp_audio/<id>`` exists.

    Returns:
        The session id reported by the backend.

    Raises:
        Exception: If the backend does not answer with HTTP 201.
    """
    r = requests.post(LOCAL_API_ENDPOINT + "/create")
    if r.status_code != 201:
        raise Exception("Failed to create chat session")
    session_id = r.json()["id"]
    # exist_ok: a folder left behind by an earlier run must not abort the call
    # (this function also runs at import time).
    os.makedirs(f"./temp_audio/{session_id}", exist_ok=True)
    return session_id
# Id of the session currently in use. It is created while the module is being
# imported, so the backend has to be reachable at start-up.
session_id = create_chat_session()
# Module-level copy of the chat transcript; written by bot(), change_session(),
# create_new_or_change_session() and reset_chat_session(), and read by
# load_chat_history().
chat_history = []
def create_new_or_change_session(history, id):
    """Start a fresh session, or switch to an existing one when an id is given.

    Args:
        history: Current chatbot history (forwarded to ``change_session``).
        id: Session id typed by the user. Blank, whitespace-only or ``None``
            means "create a new session". (The name shadows the builtin but is
            kept because it is part of the function's signature.)

    Returns:
        A tuple ``(history, update)``: the history to display and a Gradio
        update that clears and disables the session textbox.

    Raises:
        Exception: Propagated from ``create_chat_session`` / ``change_session``.
    """
    global session_id
    global chat_history
    # Normalise first so that "   " is not looked up as a real session id.
    requested_id = (id or "").strip()
    if not requested_id:
        session_id = create_chat_session()
        history = []
    else:
        history, _ = change_session(history, requested_id)
    # Keep a separate list, as the other writers of chat_history do, so the
    # stored transcript is not aliased to the value handed back to the UI.
    chat_history = list(history)
    return history, gr.update(value="", interactive=False)
def add_text(history, text):
    """Queue the user's text as a new turn that has no reply yet.

    Args:
        history: Current chatbot history.
        text: Text submitted in the textbox.

    Returns:
        A tuple of the extended history and a Gradio update that empties and
        disables the textbox.
    """
    pending_turn = (text, None)  # second slot stays None: no reply yet
    extended = history + [pending_turn]
    textbox_update = gr.update(value="", interactive=False)
    return extended, textbox_update
def add_audio(history, audio):
    """Store a recorded clip, transcribe it, and add both to the history.

    Args:
        history: Current chatbot history.
        audio: Raw payload of the audio component. Its ``'data'`` entry is read
            as base64 text; anything up to the last comma (a data-URL style
            prefix) is dropped before decoding.

    Returns:
        A tuple of the extended history (one turn holding the saved MP3 path,
        one holding the transcript) and a Gradio update for the audio input.

    Raises:
        ValueError: If no audio payload was supplied.
        Exception: If the ``/transcribe`` request does not return HTTP 200.
    """
    if not audio:
        # Fail with a clear message instead of a TypeError on None['data'].
        raise ValueError("No audio data received")

    raw_bytes = base64.b64decode(audio['data'].split(',')[-1].encode('utf-8'))

    # Encode to MP3 exactly once, into a *fresh* buffer. Exporting back into
    # the buffer that still holds the source recording does not truncate it,
    # so the uploaded bytes were not guaranteed to be only the MP3 data.
    mp3_buffer = BytesIO()
    AudioSegment.from_file(BytesIO(raw_bytes)).export(mp3_buffer, format="mp3")
    mp3_bytes = mp3_buffer.getvalue()

    # Save the clip to disk so the chatbot can reference it by path.
    audio_id = secrets.token_hex(8)
    audio_dir = f"temp_audio/{session_id}"
    # The folder only exists for sessions created by this process; make sure
    # it is there after switching to another session as well.
    os.makedirs(audio_dir, exist_ok=True)
    audio_path = f"{audio_dir}/audio_input_{audio_id}.mp3"
    with open(audio_path, "wb") as mp3_file:
        mp3_file.write(mp3_bytes)
    history = history + [((audio_path,), None)]

    response = requests.post(
        LOCAL_API_ENDPOINT + "/transcribe",
        files={'audio': mp3_bytes},
    )
    if response.status_code != 200:
        raise Exception(response.text)
    text = response.json()['text']
    history = history + [(text, None)]
    return history, gr.update(value="", interactive=False)
def reset_chat_session(history):
    """Ask the backend to reset the current session and clear the transcript.

    Args:
        history: Current chatbot history; its contents are not used.

    Returns:
        An empty history list.

    Raises:
        Exception: With the response body if the reset request does not
            return HTTP 200.
    """
    global session_id, chat_history
    reset_url = LOCAL_API_ENDPOINT + f"/reset/{session_id}"
    reply = requests.post(reset_url)
    if reply.status_code != 200:
        raise Exception(reply.text)
    # Only wipe the local copy once the backend has confirmed the reset.
    chat_history = []
    return []
def bot(history):
    """Send the latest user text to the backend and append its reply.

    The message is taken from the last history entry, or from the one before
    it when the last entry's user slot is not text.

    Args:
        history: Current chatbot history.

    Returns:
        The history extended by two turns: the reply audio (saved as an MP3
        under the session's temp folder) and the reply text.

    Raises:
        ValueError: If neither of the last two entries holds user text.
        Exception: If the ``/send/text`` request does not return HTTP 200.
    """
    global chat_history
    # Same two candidates as before (last, then second to last), but an empty
    # history or a non-text value is rejected instead of raising IndexError or
    # being sent to the backend as the message.
    message = next(
        (turn[0] for turn in reversed(history[-2:]) if isinstance(turn[0], str)),
        None,
    )
    if message is None:
        raise ValueError("No user text message to send")

    response = requests.post(
        LOCAL_API_ENDPOINT + f"/send/text/{session_id}",
        headers={'Content-type': 'application/json'},
        json={
            'message': message,
            'role': 'user',
        },
    )
    if response.status_code != 200:
        raise Exception(f"Failed to send message, {response.text}")
    payload = response.json()
    text, audio = payload['text'], payload['audio']

    audio_bytes = base64.b64decode(audio.encode('utf-8'))
    audio_id = secrets.token_hex(8)
    audio_dir = f"temp_audio/{session_id}"
    # The folder may be missing after switching to a session that was not
    # created by this process.
    os.makedirs(audio_dir, exist_ok=True)
    audio_path = f"{audio_dir}/audio_input_{audio_id}.mp3"
    AudioSegment.from_file(BytesIO(audio_bytes)).export(audio_path, format="mp3")

    history = history + [(None, (audio_path,)), (None, text)]
    chat_history = history.copy()
    return history
def _store_history_audio(session, audio_b64):
    """Decode base64 audio, save it as an MP3 in the session folder, return its path."""
    audio_dir = f"temp_audio/{session}"
    # Sessions that were not created by this process have no folder yet.
    os.makedirs(audio_dir, exist_ok=True)
    audio_path = f"{audio_dir}/audio_input_{secrets.token_hex(8)}.mp3"
    audio_bytes = base64.b64decode(audio_b64.encode('utf-8'))
    AudioSegment.from_file(BytesIO(audio_bytes)).export(audio_path, format="mp3")
    return audio_path


def change_session(history, id):
    """Switch to an existing session and rebuild its history from the backend.

    Fetches ``/<id>`` and turns every returned chat item into chatbot turns:
    an audio turn (when the item carries audio) followed by a text turn, on
    the user side for role ``'user'`` and on the bot side for ``'assistant'``.

    Args:
        history: Current chatbot history; replaced by the fetched one.
        id: Id of the session to switch to. (The name shadows the builtin but
            is kept because it is part of the function's signature.)

    Returns:
        A tuple of the rebuilt history and a Gradio update that clears and
        disables the session textbox.

    Raises:
        ValueError: If ``id`` is empty or could escape the temp audio folder.
        Exception: If the request does not return HTTP 200, or if the reply
            cannot be converted (the message then contains the reply).
    """
    global session_id
    global chat_history
    # The id is typed by the user and becomes part of a file-system path, so
    # refuse anything that is not a plain single path component.
    id_text = "" if id is None else str(id)
    if not id_text or id_text in (".", "..") or "/" in id_text or "\\" in id_text:
        raise ValueError(f"Invalid session id: {id!r}")

    response = requests.get(
        LOCAL_API_ENDPOINT + f"/{id}"
    )
    if response.status_code != 200:
        raise Exception(response.text)
    response = response.json()

    history = []
    try:
        for chat in response:
            role = chat['role']
            if role not in ('user', 'assistant'):
                raise Exception("Invalid chat role")
            from_user = role == 'user'
            # Audio is optional for both roles; skip decoding when absent.
            if chat['audio']:
                audio_path = _store_history_audio(id, chat['audio'])
                history.append(
                    ((audio_path,), None) if from_user else (None, (audio_path,))
                )
            history.append(
                (chat['message'], None) if from_user else (None, chat['message'])
            )
    except Exception as e:
        # Keep the original failure attached as the cause for debugging.
        raise Exception(f"Response: {response}") from e

    # Switch the active session only after its history was rebuilt, so a
    # failed switch leaves the previous session selected.
    session_id = id
    chat_history = history.copy()
    print(f"len(chat_history): {len(chat_history)}\nlen(history): {len(history)}\nlen(response): {len(response)}")
    return history, gr.update(value="", interactive=False)
def load_chat_history(history):
    """Return whichever is longer: the stored transcript or the given history.

    Args:
        history: History currently held by the chatbot component.

    Returns:
        The module-level ``chat_history`` when it has more entries than
        ``history``; otherwise ``history`` itself.
    """
    global chat_history
    stored_is_longer = len(chat_history) > len(history)
    return chat_history if stored_is_longer else history
# UI layout and event wiring. Statement order matters: components are created
# in display order and listeners are attached to components defined above them.
with gr.Blocks() as demo:
    with gr.Row():
        # change session id
        change_session_txt = gr.Textbox(
            show_label=False,
            placeholder=session_id,
        ).style(container=False)
    with gr.Row():
        # button to create new or change session id
        # NOTE(review): `type='success'` - button colour is usually chosen with
        # `variant`; confirm the installed Gradio version accepts `type` here.
        change_session_button = gr.Button(
            "Create new or change session", type='success', size="sm"
        ).style(margin="0 10px 0 0", container=False)
    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)
    # On page load, show the stored transcript if it is longer than the
    # component's own history (see load_chat_history).
    demo.load(load_chat_history, [chatbot], [chatbot], queue=False)
    with gr.Row():
        with gr.Column(scale=0.85):
            txt = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press enter, or record audio",
            ).style(container=False)
        with gr.Column(scale=0.15, min_width=0):
            audio = gr.Audio(
                source="microphone", type="numpy", show_label=False, format="mp3"
            ).style(container=False)
    with gr.Row():
        # NOTE(review): same `type=` vs `variant=` question as the button above.
        reset_button = gr.Button(
            "Reset Chat Session", type='stop', size="sm"
        ).style(margin="0 10px 0 0", container=False)
    # Text flow: add the user turn (textbox gets disabled), ask the backend,
    # then re-enable the textbox.
    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, chatbot, chatbot
    )
    txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)
    # Audio flow: pre/post-processing is disabled, so add_audio receives the
    # component's raw payload and reads its base64 'data' entry itself.
    audio_msg = audio.change(add_audio, [chatbot, audio], [chatbot, audio], queue=False, preprocess=False, postprocess=False).then(
        bot, chatbot, chatbot
    )
    # NOTE(review): resetting the value to None may fire `change` again and
    # call add_audio without a recording - confirm against the Gradio version.
    audio_msg.then(lambda: gr.update(interactive=True, value=None), None, [audio], queue=False)
    reset_button.click(reset_chat_session, [chatbot], [chatbot], queue=False)
    # Session switching: both paths finish by re-enabling the textbox and
    # showing the now-active session id as its placeholder (the lambdas read
    # the global session_id when they run).
    chgn_msg = change_session_txt.submit(change_session, [chatbot, change_session_txt], [chatbot, change_session_txt], queue=False)
    chgn_msg.then(lambda: gr.update(interactive=True, placeholder=session_id), None, [change_session_txt], queue=False)
    create_new_or_change_session_btn = change_session_button.click(create_new_or_change_session, [chatbot, change_session_txt], [chatbot, change_session_txt], queue=False)
    create_new_or_change_session_btn.then(lambda: gr.update(interactive=True, placeholder=session_id), None, [change_session_txt], queue=False)
# show_error=True is passed so that exceptions raised by the handlers are shown in the UI.
demo.launch(show_error=True)