# Spaces:
# Runtime error
# Runtime error
import json | |
import os | |
import openai | |
from audio_utils import text_to_speech_polly | |
from openai_utils import get_embedding, whisper_transcription | |
from vector_db import LanceVectorDb, QnA | |
# Vector store holding the question/answer pairs the bot retrieves from.
db = LanceVectorDb("qna_db")

# Read the OpenAI key from the environment (KeyError at startup if unset).
OPENAI_KEY = os.environ["OPENAI_KEY"]
openai.api_key = OPENAI_KEY

# When there is no table, or the table has no rows, seed the database from
# the JSON export.
if not db.table or len(db.table.to_pandas()) == 0:
    print("Empty db, trying to load qna's from json file")
    try:
        db.init_from_qna_json("all_questions_audio.json")
    except Exception as exception:
        # Re-raise as a specific error type while keeping the original cause.
        raise RuntimeError("Failed to initialize db from json file") from exception
    print("Initialized db from json file")
import os | |
def ensure_dir(directory):
    """Create ``directory`` (and any missing parent directories) if needed.

    ``exist_ok=True`` performs the existence check and the creation in one
    call. This avoids a race where another process creates the directory
    between a separate check and ``makedirs`` (which would raise
    ``FileExistsError``). A ``FileExistsError`` is still raised if
    ``directory`` exists but is not a directory.
    """
    os.makedirs(directory, exist_ok=True)
# Scratch directory for the per-request TTS mp3 files that handle_audiofile
# writes into "audio_temp".
ensure_dir("audio_temp")
import random | |
from langdetect import detect | |
def red(text):
    """Return ``text`` wrapped in double quotes and ANSI red color codes."""
    return '\x1b[31m"{}"\x1b[0m'.format(text)
def query_database(
    prompt: str,
    filters: "dict | None" = None,
    limit: int = 3,
    max_score: float = 0.49,
):
    """Return the stored Q&A pairs closest to ``prompt``.

    Args:
        prompt: Text to embed and search for.
        filters: Optional filters forwarded to ``db.get_qna``. ``None`` (the
            default) means no filtering. A fresh dict is created on every
            call rather than sharing one mutable default object.
        limit: Maximum number of candidates fetched from the database.
        max_score: Candidates whose ``score`` is not below this value are
            dropped. This file treats lower scores as closer matches
            (presumably a distance; confirm against ``LanceVectorDb``).

    Returns:
        The candidates that passed the score filter, in the order returned
        by ``db.get_qna``.
    """
    if filters is None:
        filters = {}
    print("Querying database for question:", prompt)
    embedding = get_embedding(prompt)
    qnas = db.get_qna(embedding, filters=filters, limit=limit)
    print("Total_qnas:", len(qnas), [qna.score for qna in qnas])
    qnas = [qna for qna in qnas if qna.score < max_score]
    print("Filtered_qnas:", len(qnas))
    return qnas
# Name -> callable lookup, presumably meant for OpenAI function calling.
# It is not referenced anywhere else in this file.
available_functions = dict(query_database=query_database)
# Per-run folder for the step_<n>_qna.json logs that bot_respond writes.
# os.makedirs without exist_ok fails atomically when the folder already
# exists, so a new run never reuses (and overwrites) an earlier run's logs
# when the random id repeats. It also creates "conversations/" if missing.
for _attempt in range(1000):
    conversation_folder = f"conversations/{random.randint(0, 10000)}"
    try:
        os.makedirs(conversation_folder)
    except FileExistsError:
        continue
    break
else:
    raise RuntimeError("Could not find an unused conversation folder name")
print("Conversation", conversation_folder)
# Instructions sent as the first ("system") message of every conversation.
SYSTEM_PROMPT = (
    "You are a question answering assistant.\n"
    "You answer questions from users delimited by triple dashes --- based on information in our database provided as context.\n"
    "The context information is delimited by triple backticks ```\n"
    "You try to be concise and offer the most relevant information.\n"
    "You answer in the language that the question was asked in.\n"
    "You speak German and English.\n"
)
# Turn counter that bot_respond uses to name its step_<n>_qna.json log files.
# bot_respond increments it after each turn answered by the LLM.
step = 0
def context_format(qnas):
    """Render retrieved Q&A pairs as a fenced context section for the prompt.

    Each item must expose ``question`` and ``answer`` attributes. The
    opening ``` fence sits on its own line, so the first entry is not read
    as the fence's info string.
    """
    entries = "".join(
        f"For question: {qna.question}\nThe answer is: {qna.answer}\n"
        for qna in qnas
    )
    return "Context:\n\n```\n" + entries + "```"
def bot_respond(user_query, history: dict):
    """Answer ``user_query``, preferring a cached database answer over GPT-4.

    Args:
        user_query: The (transcribed) user question.
        history: Conversation state dict. Its ``"chat_messages"`` list holds
            OpenAI-style messages and is extended in place with the user
            query and the assistant answer.

    Returns:
        A dict with ``"type"`` (``"cached_response"`` or ``"openai"``),
        ``"bot_response"`` and ``"prompt"``. Cached responses also carry
        ``"mp3_path"``, the path of the answer audio (synthesized if missing).
    """
    global step
    chat_messages = history["chat_messages"]

    qnas = query_database(user_query)
    # A very close match (score < 0.15) is answered straight from the
    # database, skipping the LLM call.
    if any(qna.score < 0.15 for qna in qnas):
        # min() returns the first element with the lowest score.
        best = min(qnas, key=lambda qna: qna.score)
        uid: str = best.uid
        mp3_path = os.path.join("audio", f"{uid}.mp3")
        if not os.path.exists(mp3_path):
            text_to_speech_polly(best.answer, best.language, mp3_path)
        chat_messages.append({"role": "user", "content": user_query})
        chat_messages.append({"role": "assistant", "content": best.answer})
        return {
            "type": "cached_response",
            "mp3_path": mp3_path,
            "bot_response": best.answer,
            "prompt": "No chatbot response, cached response from database",
        }

    # For the LLM context, search only entries whose source is "base".
    qnas = query_database(user_query, filters={"source": "base"})
    prompt = f"The user said: ---{user_query}---\n\n" + context_format(qnas)

    # Send the context-augmented prompt on a temporary list instead of
    # appending it to the shared history. If the API call raises, the
    # history is left untouched.
    completion = openai.ChatCompletion.create(
        model="gpt-4",
        messages=chat_messages + [{"role": "user", "content": prompt}],
        temperature=0,
    )
    response_message = completion["choices"][0]["message"]
    bot_response = response_message.content

    # Store only the plain question (without the context) in the history.
    chat_messages.append({"role": "user", "content": user_query})
    chat_messages.append({"role": "assistant", "content": bot_response})

    path = os.path.join(conversation_folder, f"step_{step}_qna.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "chat_messages": chat_messages,
                "response": bot_response,
            },
            f,
            indent=4,
            ensure_ascii=False,
        )
    step += 1
    return {
        "type": "openai",
        "bot_response": bot_response,
        "prompt": prompt,
    }
def add_question(question, path="runtime_questions.json"):
    """Append ``question`` to the JSON list stored at ``path``.

    The file is created if it does not exist. It is read and written as
    UTF-8 explicitly because the list is dumped with ``ensure_ascii=False``.
    Relying on the platform's default encoding could fail on, or mis-decode,
    non-ASCII text such as German answers.

    Args:
        question: JSON-serializable record. In this file it is a dict with
            ``question``, ``answer`` and ``language`` keys.
        path: Location of the JSON list file.
    """
    if os.path.exists(path):
        with open(path, encoding="utf-8") as f:
            questions = json.load(f)
    else:
        questions = []
    questions.append(question)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(questions, f, indent=4, ensure_ascii=False)
import random | |
def display_history(conversation):
    """Render chat messages as plain text for the debug textbox.

    Each message is shown as its role tag followed by a colon, the content,
    the closing role tag, and a blank line.
    """
    return "".join(
        "<<{0}>>:\n{1}\n<<{0}>>\n\n".format(msg["role"], msg["content"])
        for msg in conversation
    )
# Seed the downloadable questions file with an empty JSON list on first start.
if not os.path.exists("runtime_questions.json"):
    with open("runtime_questions.json", mode="w") as questions_file:
        questions_file.write(json.dumps([]))
def handle_audiofile(audio_filepath: str, history: dict):
    """Run one voice turn: transcribe, answer, and synthesize the reply.

    Args:
        audio_filepath: Path of the recorded question audio.
        history: Conversation state dict whose ``"chat_messages"`` list is
            extended in place by ``bot_respond``.

    Returns:
        A 7-tuple, in the order of the Gradio outputs:
        - the transcribed question
        - the answer text
        - the (updated) history state
        - the prompt/debug text
        - the rendered conversation
        - the path of the answer mp3
        - the path of the runtime questions file
    """
    import uuid

    from langdetect.lang_detect_exception import LangDetectException

    user_question = whisper_transcription(audio_filepath)
    print("Transcription", user_question)
    res = bot_respond(user_question, history)
    bot_response_text = res["bot_response"]

    if res["type"] == "cached_response":
        # bot_respond already resolved (and if needed synthesized) the audio.
        context_prompt = res["prompt"]
        audio_path = res["mp3_path"]
    else:
        lang = "en"
        if bot_response_text:
            try:
                lang = detect(bot_response_text)
            except LangDetectException:
                # langdetect raises this when it finds no usable features in
                # the text (e.g. only digits/punctuation); keep the "en" default.
                print("Language detection failed, using 'en' for text:", bot_response_text)
            else:
                print("Detected language:", lang, "for text:", bot_response_text)
        # Record the new Q&A pair with the detected language before the
        # language is clamped for TTS below.
        add_question(
            {"question": user_question, "answer": bot_response_text, "language": lang}
        )
        # TTS is only ever requested in "en" or "de"; anything else falls back to "en".
        if lang not in ["en", "de"]:
            lang = "en"
        # A UUID file name keeps concurrent requests from overwriting each
        # other's audio, which a small random index could not guarantee.
        audio_path = os.path.join("audio_temp", f"output_{uuid.uuid4().hex}.mp3")
        text_to_speech_polly(bot_response_text, lang, audio_path)
        context_prompt = res["prompt"]
        context_prompt += f"<<tts language>> : {lang}\n"
        context_prompt += f"<<tts text>> : {bot_response_text}\n"

    return (
        user_question,
        bot_response_text,
        history,
        context_prompt,
        display_history(history["chat_messages"]),
        audio_path,
        "runtime_questions.json",
    )
import gradio as gr | |
# Gradio UI: microphone input, TTS playback, and debug/inspection textboxes.
with gr.Blocks() as demo:
    # initialize the state that will be used to store the chat messages
    # (passed into handle_audiofile and returned from it as `history`;
    # seeded with the system prompt as the first message)
    chat_messages = gr.State(
        {
            "chat_messages": [{"role": "system", "content": SYSTEM_PROMPT}],
        }
    )
    with gr.Row():
        # Microphone recording, handed to the handler as a file path.
        # NOTE(review): `source=` is the Gradio 3.x keyword (Gradio 4 expects
        # `sources=[...]`) -- confirm against the pinned Gradio version.
        audio_input = gr.Audio(source="microphone", type="filepath", format="mp3")
        # autoplay=True => run the output audio file automatically
        output_audio = gr.Audio(label="PhoneBot Answer TTS", autoplay=True)
    with gr.Row():
        user_query_textbox = gr.Textbox(label="User Query")
        assistant_answer = gr.Textbox(label="PhoneBot Answer")
    with gr.Row():
        context_info = gr.Textbox(
            label="Context provided to the bot + additional infos for debugging"
        )
        conversation_history = gr.Textbox(label="Conversation history")
    with gr.Row():
        file_output = gr.File(label="Download questions file")
    # When recording stops, run handle_audiofile. The order of `outputs`
    # must match the 7-tuple that handle_audiofile returns.
    audio_input.stop_recording(
        handle_audiofile,
        inputs=[audio_input, chat_messages],
        outputs=[
            user_query_textbox,
            assistant_answer,
            chat_messages,
            context_info,
            conversation_history,
            output_audio,
            file_output,
        ],
    )
# Basic-auth credentials for the UI are read from the environment; a missing
# variable raises KeyError before the server starts.
username, password = (
    os.environ[var_name] for var_name in ("GRADIO_USERNAME", "GRADIO_PASSWORD")
)
# Launch the app behind that login.
demo.launch(auth=(username, password))