{ "cells": [
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import json, os\n", "\n", "# Load API keys and other secrets from a local json file into the environment\n", "with open(\"env.json\") as f:\n", "    env_vars = json.load(f)\n", "\n", "for k, v in env_vars.items():\n", "    os.environ[k] = v" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "import openai\n", "\n", "from vector_db import LanceVectorDb, QnA\n", "from openai_utils import get_embedding, whisper_transcription\n", "from audio_utils import text_to_speech_polly\n", "\n", "db = LanceVectorDb(\"qna_db\")\n", "openai.api_key = os.environ[\"OPENAI_KEY\"]" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Seed the vector store from the prepared QnA json file if it is still empty\n", "if not db.table or len(db.table.to_pandas()) == 0:\n", "    print(\"Empty db, trying to load QnAs from json file\")\n", "    try:\n", "        db.init_from_qna_json(\"all_questions_audio.json\")\n", "        print(\"Initialized db from json file\")\n", "    except Exception as exception:\n", "        raise RuntimeError(\"Failed to initialize db from json file\") from exception" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "def ensure_dir(directory):\n", "    \"\"\"Create the directory if it does not exist yet.\"\"\"\n", "    if not os.path.exists(directory):\n", "        os.makedirs(directory)\n", "\n", "ensure_dir(\"audio_temp\")" ] },
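{ "cell_type": "markdown", "metadata": {}, "source": [ "Optional sanity check of the retrieval path: embed a sample question and query the vector store directly. This is a minimal sketch assuming the db was seeded above; the sample question is purely illustrative. Lower scores mean closer matches (the bot later treats a score below 0.15 as an exact hit and discards anything at 0.49 or above)." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Retrieval smoke test (a sketch; the question below is hypothetical)\n", "sample_question = \"What are your opening hours?\"\n", "sample_embedding = get_embedding(sample_question)\n", "for qna in db.get_qna(sample_embedding, filters={}, limit=3):\n", "    # lower score = closer match\n", "    print(f\"score={qna.score:.3f} question={qna.question!r}\")" ] },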
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Conversation conversations/3331\n" ] } ], "source": [ "from langdetect import detect\n", "import random\n", "\n", "def red(text):\n", "    return f'\\x1b[31m\"{text}\"\\x1b[0m'\n", "\n", "\n", "def query_database(prompt: str, filters: dict | None = None):\n", "    print(\"Querying database for question:\", prompt)\n", "    embedding = get_embedding(prompt)\n", "    qnas = db.get_qna(embedding, filters=filters or {}, limit=3)\n", "    print(\"Total_qnas:\", len(qnas), [qna.score for qna in qnas])\n", "    # Drop weak matches; scores are distances, so lower is better\n", "    qnas = [qna for qna in qnas if qna.score < 0.49]\n", "    print(\"Filtered_qnas:\", len(qnas))\n", "    return qnas\n", "\n", "\n", "available_functions = {\n", "    \"query_database\": query_database,\n", "}\n", "\n", "conversation_folder = f\"conversations/{random.randint(0, 10000)}\"\n", "ensure_dir(conversation_folder)\n", "print(\"Conversation\", conversation_folder)\n", "\n", "SYSTEM_PROMPT = (\n", "    \"You are a question answering assistant.\\n\"\n", "    \"You answer questions from users delimited by triple dashes --- based on information in our database provided as context.\\n\"\n", "    \"The context information is delimited by triple backticks ```\\n\"\n", "    \"You try to be concise and offer the most relevant information.\\n\"\n", "    \"You answer in the language that the question was asked in.\\n\"\n", "    \"You speak German and English.\\n\"\n", ")\n", "\n", "step = 0\n", "\n", "def context_format(qnas):\n", "    context = \"Context:\\n\\n```\"\n", "    for qna in qnas:\n", "        context += f\"For question: {qna.question}\\nThe answer is: {qna.answer}\\n\"\n", "    context += \"```\"\n", "    return context\n", "\n", "\n", "def bot_respond(user_query, history: dict):\n", "    global step\n", "\n", "    chat_messages = history[\"chat_messages\"]\n", "\n", "    qnas = query_database(user_query)\n", "\n", "    # Try to match an already existing question\n", "    if any(qna.score < 0.15 for qna in qnas):\n", "        min_score = min(qna.score for qna in qnas)\n", "        qna_minscore = [qna for qna in qnas if qna.score == min_score][0]\n", "        uid: str = qna_minscore.uid\n", "        mp3_path = os.path.join(\"audio\", f\"{uid}.mp3\")\n", "\n", "        # Synthesize the cached answer once, then reuse the mp3\n", "        if not os.path.exists(mp3_path):\n", "            text_to_speech_polly(qna_minscore.answer, qna_minscore.language, mp3_path)\n", "\n", "        chat_messages.append({\"role\": \"user\", \"content\": user_query})\n", "        chat_messages.append({\"role\": \"assistant\", \"content\": qna_minscore.answer})\n", "\n", "        return {\n", "            \"type\": \"cached_response\",\n", "            \"mp3_path\": mp3_path,\n", "            \"bot_response\": qna_minscore.answer,\n", "            \"prompt\": \"No chatbot response, cached response from database\",\n", "        }\n", "\n", "    # Otherwise search only the base question set for context\n", "    qnas = query_database(user_query, filters={\"source\": \"base\"})\n", "\n", "    # Use chatgpt to answer the question\n", "    prompt = f\"The user said: ---{user_query}---\\n\\n\"\n", "    prompt += context_format(qnas)\n", "\n", "    # Send the context-enriched prompt to the model ...\n", "    chat_messages.append({\"role\": \"user\", \"content\": prompt})\n", "\n", "    completion = openai.ChatCompletion.create(\n", "        model=\"gpt-4\", messages=chat_messages, temperature=0\n", "    )\n", "\n", "    response_message = completion[\"choices\"][0][\"message\"]\n", "    bot_response = response_message.content\n", "\n", "    # ... but store only the plain user query in the history\n", "    chat_messages.pop(-1)\n", "    chat_messages.append({\"role\": \"user\", \"content\": user_query})\n", "    chat_messages.append({\"role\": \"assistant\", \"content\": bot_response})\n", "\n", "    path = os.path.join(conversation_folder, f\"step_{step}_qna.json\")\n", "    with open(path, \"w\") as f:\n", "        json.dump(\n", "            {\n", "                \"chat_messages\": chat_messages,\n", "                \"response\": response_message.content,\n", "            },\n", "            f,\n", "            indent=4,\n", "        )\n", "\n", "    step += 1\n", "\n", "    return {\n", "        \"type\": \"openai\",\n", "        \"bot_response\": bot_response,\n", "        \"prompt\": prompt,\n", "    }" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def add_question(question):\n", "    \"\"\"Append a question/answer record to the runtime questions log.\"\"\"\n", "    if os.path.exists(\"runtime_questions.json\"):\n", "        with open(\"runtime_questions.json\") as f:\n", "            questions = json.load(f)\n", "    else:\n", "        questions = []\n", "\n", "    questions.append(question)\n", "\n", "    with open(\"runtime_questions.json\", \"w\") as f:\n", "        json.dump(questions, f, indent=4, ensure_ascii=False)\n" ] },
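{ "cell_type": "markdown", "metadata": {}, "source": [ "`bot_respond` returns one of two result shapes: a `cached_response` (a match scored below 0.15, so the answer and its mp3 come straight from the database) or an `openai` response generated from the retrieved context. The cell below is a sketch of a dry run against a fresh history; it needs a valid `OPENAI_KEY` and a seeded db, and the query text is hypothetical." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Dry run of bot_respond outside the Gradio UI (a sketch; the query is hypothetical)\n", "test_history = {\"chat_messages\": [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}]}\n", "result = bot_respond(\"Wie kann ich euch telefonisch erreichen?\", test_history)\n", "print(result[\"type\"])  # either \"cached_response\" or \"openai\"\n", "print(result[\"bot_response\"])" ] },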
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "import random\n", "\n", "\n", "def display_history(conversation):\n", "    conversation_string = \"\"\n", "    for message in conversation:\n", "        conversation_string += (\n", "            f\"<<{message['role']}>>:\\n{message['content']}\\n<<{message['role']}>>\\n\\n\"\n", "        )\n", "    return conversation_string\n", "\n", "if not os.path.exists(\"runtime_questions.json\"):\n", "    with open(\"runtime_questions.json\", \"w\") as f:\n", "        json.dump([], f)\n", "\n", "def handle_audiofile(audio_filepath: str, history: dict):\n", "    user_question = whisper_transcription(audio_filepath)\n", "    print(\"Transcription\", user_question)\n", "\n", "    res = bot_respond(user_question, history)\n", "\n", "    if res[\"type\"] == \"cached_response\":\n", "        return (\n", "            user_question,\n", "            res[\"bot_response\"],\n", "            history,\n", "            res[\"prompt\"],\n", "            display_history(history[\"chat_messages\"]),\n", "            res[\"mp3_path\"],\n", "            \"runtime_questions.json\",\n", "        )\n", "    else:\n", "        bot_response_text = res[\"bot_response\"]\n", "        prompt = res[\"prompt\"]\n", "\n", "        if bot_response_text:\n", "            lang = detect(bot_response_text)\n", "            print(\"Detected language:\", lang, \"for text:\", bot_response_text)\n", "        else:\n", "            lang = \"en\"\n", "\n", "        add_question({\"question\": user_question, \"answer\": bot_response_text, \"language\": lang})\n", "\n", "        # TTS is only configured for English and German\n", "        if lang not in [\"en\", \"de\"]:\n", "            lang = \"en\"\n", "\n", "        output_filepath = os.path.join(\"audio_temp\", f\"output_{random.randint(0, 1000)}.mp3\")\n", "\n", "        text_to_speech_polly(bot_response_text, lang, output_filepath)\n", "\n", "        context_prompt = prompt\n", "        context_prompt += f\"<<detected_language>> : {lang}\\n\"\n", "        context_prompt += f\"<<bot_response>> : {bot_response_text}\\n\"\n", "\n", "        return (\n", "            user_question,\n", "            bot_response_text,\n", "            history,\n", "            context_prompt,\n", "            display_history(history[\"chat_messages\"]),\n", "            output_filepath,\n", "            \"runtime_questions.json\",\n", "        )" ] },
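{ "cell_type": "markdown", "metadata": {}, "source": [ "Quick offline check of the `display_history` formatting on a toy conversation. No API calls; the messages are made up for illustration." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Toy conversation, purely illustrative\n", "toy_conversation = [\n", "    {\"role\": \"user\", \"content\": \"Hallo!\"},\n", "    {\"role\": \"assistant\", \"content\": \"Hallo! Wie kann ich helfen?\"},\n", "]\n", "print(display_history(toy_conversation))" ] },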
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7860\n", "Running on public URL: https://21d8b4f54c5ce2bb30.gradio.live\n", "\n", "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n" ] }, { "data": { "text/plain": [] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "name": "stdout", "output_type": "stream", "text": [ "Transcription Hello, my name is Leo.\n", "Querying database for question: Hello, my name is Leo.\n", "\n", "Total_qnas: 3 [0.43892043828964233, 0.44170859456062317, 0.4578746557235718]\n", "Filtered_qnas: 3\n", "Querying database for question: Hello, my name is Leo.\n", "source == 'base'\n", "Total_qnas: 1 [0.43892043828964233]\n", "Filtered_qnas: 1\n", "Detected language: en for text: Hello Leo! How can I assist you today?\n" ] } ], "source": [ "import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", "    # initialize the state that will be used to store the chat messages\n", "    chat_messages = gr.State(\n", "        {\n", "            \"chat_messages\": [{\"role\": \"system\", \"content\": SYSTEM_PROMPT}],\n", "        }\n", "    )\n", "\n", "    with gr.Row():\n", "        audio_input = gr.Audio(source=\"microphone\", type=\"filepath\", format=\"mp3\")\n", "        # autoplay=True => play the answer audio file automatically\n", "        output_audio = gr.Audio(label=\"PhoneBot Answer TTS\", autoplay=True)\n", "    with gr.Row():\n", "        user_query_textbox = gr.Textbox(label=\"User Query\")\n", "        assistant_answer = gr.Textbox(label=\"PhoneBot Answer\")\n", "\n", "    with gr.Row():\n", "        context_info = gr.Textbox(\n", "            label=\"Context provided to the bot + additional info for debugging\"\n", "        )\n", "        conversation_history = gr.Textbox(label=\"Conversation history\")\n", "\n", "    with gr.Row():\n", "        file_output = gr.File(label=\"Download questions file\")\n", "\n", "    # when the recording is stopped, transcribe the audio and answer\n", "    audio_input.stop_recording(\n", "        handle_audiofile,\n", "        inputs=[audio_input, chat_messages],\n", "        outputs=[\n", "            user_query_textbox,\n", "            assistant_answer,\n", "            chat_messages,\n", "            context_info,\n", "            conversation_history,\n", "            output_audio,\n", "            file_output,\n", "        ],\n", "    )\n", "\n", "demo.launch(share=True, inbrowser=True, inline=False)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": ".env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.2" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }