# Page header captured with this file — author: Rifky, commit df030ab (7.49 kB):
# "Added a reset button, a create-new-session button, and session switching."
import gradio as gr
import base64
import requests
import secrets
import os
from io import BytesIO
from pydub import AudioSegment
# Base URLs of the chat backend. Each can be overridden with an environment
# variable of the same name. A trailing "/" is stripped because request URLs
# are built by appending a path, e.g. LOCAL_API_ENDPOINT + "/create".
LOCAL_API_ENDPOINT = os.environ.get(
    "LOCAL_API_ENDPOINT", "http://localhost:5000"
).rstrip("/")
# NOTE(review): PUBLIC_API_ENDPOINT is not referenced by the visible code;
# all requests below go to LOCAL_API_ENDPOINT.
PUBLIC_API_ENDPOINT = os.environ.get(
    "PUBLIC_API_ENDPOINT", "http://121.176.153.117:5000"
).rstrip("/")
def create_chat_session():
    """Create a new chat session on the backend and prepare its audio folder.

    Sends ``POST {LOCAL_API_ENDPOINT}/create`` and expects HTTP 201 with a JSON
    body containing the new session's ``"id"``. A per-session directory
    ``./temp_audio/<id>`` is created to hold the audio clips shown in the chat.

    Returns:
        The new session id as returned by the backend.

    Raises:
        Exception: if the backend does not answer with HTTP 201 (the message
            includes the backend's response text).
    """
    r = requests.post(LOCAL_API_ENDPOINT + "/create")
    if r.status_code != 201:
        raise Exception(f"Failed to create chat session: {r.text}")
    session_id = r.json()["id"]
    # exist_ok: a leftover folder from an earlier run must not make session
    # creation crash with FileExistsError.
    os.makedirs(f"./temp_audio/{session_id}", exist_ok=True)
    return session_id


# Module-level state shared by every Gradio handler below.
# NOTE(review): these are process-wide globals, so all connected browser tabs
# appear to share one active session and one cached history — confirm intended.
session_id = create_chat_session()  # id of the currently active backend session
chat_history = []  # last rendered chatbot history, restored on page load
def create_new_or_change_session(history, id):
    """Handle the "Create new or change session" button.

    A blank ``id`` (empty, whitespace-only or ``None``) creates a brand-new
    backend session and clears the chat. Any other value switches to that
    session via ``change_session``, which loads its stored history.

    Args:
        history: current chatbot history; it is replaced, not extended.
        id: contents of the session-id textbox. Surrounding whitespace, such
            as from a pasted id, is ignored.

    Returns:
        ``(history, textbox_update)``: the new chatbot history, and an update
        that clears and temporarily disables the session-id textbox.
    """
    global session_id
    global chat_history
    requested_id = (id or "").strip()
    if not requested_id:
        session_id = create_chat_session()
        history = []
    else:
        history, _ = change_session(history, requested_id)
    chat_history = list(history)
    return history, gr.update(value="", interactive=False)
def add_text(history, text):
    """Append a typed user message to the chat.

    The new entry is ``(text, None)``; the bot's half stays empty until
    ``bot`` fills in the reply.

    Returns:
        ``(history, textbox_update)``: the extended history, and an update
        that clears the input box and locks it while the bot answers.
    """
    user_turn = (text, None)
    extended = history + [user_turn]
    locked_box = gr.update(interactive=False, value="")
    return extended, locked_box
def add_audio(history, audio):
    """Save a microphone recording, transcribe it and append both to the chat.

    ``audio`` is the raw (not preprocessed) Audio-component value. Its
    ``"data"`` entry is a data URL whose base64 payload is decoded here. The
    recording is converted to mp3 exactly once. That mp3 is saved under
    ``temp_audio/<session_id>/`` so the chat can play it back, and the same
    bytes are uploaded to ``POST {LOCAL_API_ENDPOINT}/transcribe``.

    Args:
        history: current chatbot history.
        audio: raw Gradio audio payload (dict with a ``"data"`` data URL).

    Returns:
        ``(history, audio_update)``: the history extended with an audio bubble
        and the transcribed text, and an update that clears and disables the
        Audio component.

    Raises:
        Exception: if the transcription endpoint does not answer with HTTP 200
            (the message is the backend's response text).
    """
    # Keep only the base64 payload after the "data:...;base64," prefix.
    raw_bytes = base64.b64decode(audio["data"].split(",")[-1].encode("utf-8"))

    # Encode into a *fresh* buffer. Writing the mp3 back into the buffer the
    # recording was read from would leave original bytes mixed in (a BytesIO
    # is never truncated on write), corrupting both the upload and the file.
    mp3_buffer = BytesIO()
    AudioSegment.from_file(BytesIO(raw_bytes)).export(mp3_buffer, format="mp3")
    mp3_bytes = mp3_buffer.getvalue()

    # The session folder may not exist yet, e.g. for a session id that was not
    # created by this process.
    audio_dir = f"temp_audio/{session_id}"
    os.makedirs(audio_dir, exist_ok=True)
    audio_path = f"{audio_dir}/audio_input_{secrets.token_hex(8)}.mp3"
    with open(audio_path, "wb") as audio_out:
        audio_out.write(mp3_bytes)
    history = history + [((audio_path,), None)]

    response = requests.post(
        LOCAL_API_ENDPOINT + "/transcribe",
        files={"audio": mp3_bytes},
    )
    if response.status_code != 200:
        raise Exception(response.text)
    text = response.json()["text"]
    history = history + [(text, None)]
    return history, gr.update(value="", interactive=False)
def reset_chat_session(history):
    """Clear the active session on the backend and empty the chat window.

    Sends ``POST {LOCAL_API_ENDPOINT}/reset/<session_id>``. On HTTP 200 the
    global ``chat_history`` is emptied and an empty history is returned for
    the chatbot. Any other status raises ``Exception`` carrying the backend's
    response text.
    """
    global chat_history
    reset_url = f"{LOCAL_API_ENDPOINT}/reset/{session_id}"
    resp = requests.post(reset_url)
    if resp.status_code == 200:
        chat_history = []
        return []
    raise Exception(resp.text)
def bot(history):
    """Send the latest user message to the backend and append its reply.

    The message sent is the first element of the last history entry when that
    is a string (typed text, or a transcription). Otherwise it is the first
    element of the entry before it. The backend must answer
    ``POST {LOCAL_API_ENDPOINT}/send/text/<session_id>`` with HTTP 200 and a
    JSON body holding ``"text"`` and base64 ``"audio"``. The audio, when
    present, is re-encoded to an mp3 under ``temp_audio/<session_id>/`` and
    shown as an audio bubble, followed by the reply text.

    Args:
        history: chatbot history as a list of ``(user, bot)`` pairs.

    Returns:
        The extended history; a copy is also cached in the global
        ``chat_history``. A history with nothing to send (empty, or a single
        non-text entry) is returned unchanged.

    Raises:
        Exception: if the backend does not answer with HTTP 200.
    """
    global chat_history

    if not history:
        return history
    if isinstance(history[-1][0], str):
        message = history[-1][0]
    elif len(history) >= 2:
        message = history[-2][0]
    else:
        return history

    response = requests.post(
        LOCAL_API_ENDPOINT + f"/send/text/{session_id}",
        headers={"Content-type": "application/json"},
        json={"message": message, "role": "user"},
    )
    if response.status_code != 200:
        raise Exception(f"Failed to send message, {response.text}")
    reply = response.json()
    text, audio = reply["text"], reply.get("audio")

    if audio:
        # The folder may not exist yet, e.g. after switching to a session id
        # that was not created by this process.
        audio_dir = f"temp_audio/{session_id}"
        os.makedirs(audio_dir, exist_ok=True)
        audio_path = f"{audio_dir}/audio_input_{secrets.token_hex(8)}.mp3"
        audio_bytes = base64.b64decode(audio.encode("utf-8"))
        AudioSegment.from_file(BytesIO(audio_bytes)).export(audio_path, format="mp3")
        history = history + [(None, (audio_path,))]
    history = history + [(None, text)]

    chat_history = history.copy()
    return history
def _export_session_audio(session, b64_audio):
    """Decode base64 audio, save it as mp3 under ``temp_audio/<session>/``, return the path."""
    audio_dir = f"temp_audio/{session}"
    # Sessions opened by id (rather than created by this process) have no
    # folder yet.
    os.makedirs(audio_dir, exist_ok=True)
    audio_path = f"{audio_dir}/audio_input_{secrets.token_hex(8)}.mp3"
    audio_bytes = base64.b64decode(b64_audio.encode("utf-8"))
    AudioSegment.from_file(BytesIO(audio_bytes)).export(audio_path, format="mp3")
    return audio_path


def change_session(history, id):
    """Switch to an existing backend session and rebuild its chat history.

    Fetches ``GET {LOCAL_API_ENDPOINT}/<id>``, which must return HTTP 200 and a
    JSON list of messages. Each message has ``"role"`` (``"user"`` or
    ``"assistant"``), ``"message"``, and base64 ``"audio"``, which may be
    empty. Each message becomes an optional audio bubble followed by its text.
    The globals ``session_id`` and ``chat_history`` are only updated after the
    whole history was rebuilt, so a failed switch leaves the current session
    untouched.

    Args:
        history: current chatbot history; ignored and replaced.
        id: id of the session to open; surrounding whitespace is ignored.

    Returns:
        ``(history, textbox_update)``: the rebuilt history, and an update that
        clears and disables the session-id textbox.

    Raises:
        Exception: if ``id`` is blank, the backend does not answer with HTTP
            200, or a message cannot be converted (unknown role, bad audio).
    """
    global session_id
    global chat_history

    requested_id = (id or "").strip()
    if not requested_id:
        raise Exception("Session id must not be empty")

    response = requests.get(LOCAL_API_ENDPOINT + f"/{requested_id}")
    if response.status_code != 200:
        raise Exception(response.text)
    messages = response.json()

    history = []
    try:
        for chat in messages:
            role = chat["role"]
            if role == "user":
                if chat.get("audio"):
                    history.append(((_export_session_audio(requested_id, chat["audio"]),), None))
                history.append((chat["message"], None))
            elif role == "assistant":
                if chat.get("audio"):
                    history.append((None, (_export_session_audio(requested_id, chat["audio"]),)))
                history.append((None, chat["message"]))
            else:
                raise Exception(f"Invalid chat role: {role!r}")
    except Exception as e:
        # Chain the real cause and show it next to the raw response, so a
        # broken session can be diagnosed from the UI error message.
        raise Exception(
            f"Failed to load session {requested_id}: {e}. Response: {messages}"
        ) from e

    session_id = requested_id
    chat_history = history.copy()
    return history, gr.update(value="", interactive=False)
def load_chat_history(history):
    """Restore the conversation when the page is (re)loaded.

    If the globally cached ``chat_history`` has more entries than ``history``
    (the value the page arrived with), the cached list is returned. Otherwise
    ``history`` is returned unchanged.
    """
    cached = chat_history
    return cached if len(cached) > len(history) else history
# UI layout and event wiring.
with gr.Blocks() as demo:
    with gr.Row():
        # Session-id box: type an existing id to switch to it. The placeholder
        # shows the currently active session id.
        change_session_txt = gr.Textbox(
            show_label=False,
            placeholder=session_id,
        ).style(container=False)
    with gr.Row():
        # Creates a new session when the id box is empty, else switches to it.
        # Gradio buttons are styled via `variant` ("primary"/"secondary"/"stop").
        change_session_button = gr.Button(
            "Create new or change session", variant="primary", size="sm"
        ).style(margin="0 10px 0 0", container=False)

    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)
    # Restore the cached conversation when the page is (re)loaded.
    demo.load(load_chat_history, [chatbot], [chatbot], queue=False)

    with gr.Row():
        with gr.Column(scale=0.85):
            txt = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press enter, or record audio",
            ).style(container=False)
        with gr.Column(scale=0.15, min_width=0):
            audio = gr.Audio(
                source="microphone", type="numpy", show_label=False, format="mp3"
            ).style(container=False)
    with gr.Row():
        reset_button = gr.Button(
            "Reset Chat Session", variant="stop", size="sm"
        ).style(margin="0 10px 0 0", container=False)

    # Text: append the message, let the bot answer, then unlock the textbox.
    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, chatbot, chatbot
    )
    txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)

    # Audio: preprocess/postprocess are off so add_audio receives the raw
    # payload and decodes audio["data"] itself.
    audio_msg = audio.change(
        add_audio, [chatbot, audio], [chatbot, audio],
        queue=False, preprocess=False, postprocess=False,
    ).then(bot, chatbot, chatbot)
    # NOTE(review): resetting the value to None may itself fire `audio.change`
    # (with audio=None) depending on the Gradio version — confirm.
    audio_msg.then(lambda: gr.update(interactive=True, value=None), None, [audio], queue=False)

    reset_button.click(reset_chat_session, [chatbot], [chatbot], queue=False)

    # The lambdas read the global `session_id` when they run, so the
    # placeholder shows the session that is active after the switch.
    chgn_msg = change_session_txt.submit(
        change_session, [chatbot, change_session_txt], [chatbot, change_session_txt], queue=False
    )
    chgn_msg.then(lambda: gr.update(interactive=True, placeholder=session_id), None, [change_session_txt], queue=False)
    create_new_or_change_session_btn = change_session_button.click(
        create_new_or_change_session, [chatbot, change_session_txt], [chatbot, change_session_txt], queue=False
    )
    create_new_or_change_session_btn.then(
        lambda: gr.update(interactive=True, placeholder=session_id), None, [change_session_txt], queue=False
    )

if __name__ == "__main__":
    demo.launch(show_error=True)