neke-leo committed
Commit a03292c
0 Parent(s):

ENH: Initial commit

.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+ # ignore old_files folder
+ old_files
+ audio
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Phone Bot Demo
+ emoji: 🐠
+ colorFrom: blue
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 3.44.0
+ app_file: app.py
+ pinned: false
+ license: unknown
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
all_questions.json ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,154 @@
+ import os
+
+ import openai
+
+ from audio_utils import text_to_speech, text_to_speech_polly
+ from openai_utils import get_embedding, whisper_transcription
+ from vector_db import LanceVectorDb, QnA
+
+ db = LanceVectorDb("qna_db")
+
+ OPENAI_KEY = os.environ["OPENAI_KEY"]
+ openai.api_key = OPENAI_KEY
+
+ if len(db.table.to_pandas()) == 0:
+     print("Empty db, trying to load qna's from json file")
+     try:
+         db.init_from_qna_json("all_questions.json")
+         print("Initialized db from json file")
+     except Exception as exception:
+         raise Exception("Failed to initialize db from json file") from exception
+
+ import os
+
+
+ def ensure_dir(directory):
+     if not os.path.exists(directory):
+         os.makedirs(directory)
+
+
+ ensure_dir("audio")
+
+ from langdetect import detect
+
+ GENERAL_SYSTEM_PROMPT = (
+     "You are a question answering assistant.\n"
+     "You answer questions from users based on information in our database provided as context.\n"
+     "You respond in one, maximum two sentences.\n"
+     "You use only the information in the context. If the information is not in the context, you tell the user that you don't know.\n"
+     "You answer in the language that the question was asked in.\n"
+     "You speak german and english.\n"
+ )
+
+
+ def bot_respond(user_query, chat_messages: list):
+     embedding = get_embedding(user_query)
+
+     qnas = db.get_qna(embedding, lang="en", limit=3)
+     print("Total_qnas:", len(qnas), [qna.score for qna in qnas])
+     qnas = [qna for qna in qnas if qna.score < 0.45]
+     print("Filtered_qnas:", len(qnas))
+     context_prompt = f"The user said: {user_query}\n\n"
+
+     if len(qnas) > 0:
+         example_questions = ""
+         for qna in qnas:
+             example_questions += (
+                 f"For question: '{qna.question}'\n" f"The answer is: '{qna.answer}'\n"
+             )
+
+         context_prompt += (
+             f"Context information from our database:\n{example_questions}"
+             "If the user hasn't provided some of the necessary information for answering the question, you can ask the user for it.\n"
+         )
+
+     print(context_prompt)
+     chat_messages.append({"role": "user", "content": context_prompt})
+
+     completion = openai.ChatCompletion.create(
+         model="gpt-3.5-turbo",
+         messages=chat_messages,
+     )
+     text = completion.choices[0].message.content
+
+     chat_messages.pop()
+     chat_messages.append({"role": "user", "content": user_query})
+     chat_messages.append({"role": "system", "content": text})
+
+     return text, context_prompt
+
+
+ import random
+
+
+ def display_history(conversation):
+     conversation_string = ""
+     for message in conversation:
+         conversation_string += (
+             f"<<{message['role']}>>:\n{message['content']}\n<<{message['role']}>>\n\n"
+         )
+     return conversation_string
+
+
+ def handle_audiofile(audio_filepath: str, chat_messages: list):
+     user_question = whisper_transcription(audio_filepath)
+     print("Transcription", user_question)
+
+     bot_response_text, context_prompt = bot_respond(user_question, chat_messages)
+
+     lang = detect(bot_response_text)
+     print("Detected language:", lang, "for text:", bot_response_text)
+
+     if lang not in ["en", "de"]:
+         lang = "en"
+
+     output_filepath = os.path.join("audio", f"output_{random.randint(0, 1000)}.mp3")
+     text_to_speech_polly(bot_response_text, lang, output_filepath)
+
+     context_prompt += f"<<tts language>> : {lang}\n"
+     context_prompt += f"<<tts text>> : {bot_response_text}\n"
+
+     return (
+         user_question,
+         bot_response_text,
+         chat_messages,
+         context_prompt,
+         display_history(chat_messages),
+         output_filepath,
+     )
+
+
+ import gradio as gr
+
+ with gr.Blocks() as demo:
+     # initialize the state that will be used to store the chat messages
+     chat_messages = gr.State([{"role": "system", "content": GENERAL_SYSTEM_PROMPT}])
+
+     with gr.Row():
+         audio_input = gr.Audio(source="microphone", type="filepath", format="mp3")
+         # autoplay=True => play the output audio file automatically
+         output_audio = gr.Audio(label="PhoneBot Answer TTS", autoplay=True)
+     with gr.Row():
+         user_query_textbox = gr.Textbox(label="User Query")
+         assistant_answer = gr.Textbox(label="PhoneBot Answer")
+
+     with gr.Row():
+         context_info = gr.Textbox(label="Provided context")
+         conversation_history = gr.Textbox(label="Conversation history")
+
+     # when the recording is stopped, run the transcription + answer pipeline
+     audio_input.stop_recording(
+         handle_audiofile,
+         inputs=[audio_input, chat_messages],
+         outputs=[
+             user_query_textbox,
+             assistant_answer,
+             chat_messages,
+             context_info,
+             conversation_history,
+             output_audio,
+         ],
+     )
+
+ # launch app
+ demo.launch(auth=("phonebotuser", "pbotpasswrd"))
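
Note: app.py takes its credentials from the environment rather than from source: OPENAI_KEY here, plus POLLY_URL and POLLY_KEY in audio_utils.py. A minimal local-run sketch, assuming placeholder values (on the Space itself these would be configured as secrets, not hard-coded):

import os

# placeholder values for local testing only; never commit real keys
os.environ["OPENAI_KEY"] = "sk-..."                      # read by app.py
os.environ["POLLY_URL"] = "https://example.invalid/tts"  # read by audio_utils.py (endpoint shape assumed)
os.environ["POLLY_KEY"] = "dummy-polly-key"              # read by audio_utils.py

# importing the module builds the LanceDB table, the Gradio Blocks UI, and calls demo.launch()
import app  # noqa: E402,F401
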
audio_utils.py ADDED
@@ -0,0 +1,48 @@
+ import base64
+ import os
+ from typing import Literal
+
+ import requests
+ from gtts import gTTS
+
+
+ def text_to_speech(
+     text, language: Literal["de", "en"] = "de", save_path: str = "output.mp3"
+ ):
+     tts = gTTS(text=text, lang=language, slow=False)
+     tts.save(save_path)
+
+
+ LANG_TO_VOICE_MAPPING = {
+     "de": "Vicki",
+     "en": "Joanna",
+ }
+
+
+ POLLY_URL = os.environ["POLLY_URL"]
+ POLLY_KEY = os.environ["POLLY_KEY"]
+
+
+ def text_to_speech_polly(
+     text, language: Literal["de", "en"] = "de", save_path: str = "output.mp3"
+ ):
+     json_data = {
+         "text": text,
+         "voice": LANG_TO_VOICE_MAPPING.get(language, "Joanna"),
+         "prefered_engine": "neural",
+         "code": POLLY_KEY,
+     }
+
+     response = requests.post(POLLY_URL, json=json_data)
+
+     try:
+         response.raise_for_status()
+     except requests.exceptions.HTTPError as error:
+         print(error)
+         print(response.text)
+         return
+
+     binary_data = base64.b64decode(response.content)
+
+     with open(save_path, "wb") as f:
+         f.write(binary_data)
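
For orientation, a small usage sketch of the two TTS helpers (the text and file names are placeholders; note that merely importing the module already requires POLLY_URL and POLLY_KEY in the environment):

from audio_utils import text_to_speech, text_to_speech_polly

# gTTS-based synthesis, the simpler path with no extra credentials
text_to_speech("Hallo, wie kann ich Ihnen helfen?", language="de", save_path="hallo_de.mp3")

# Polly-backed synthesis via the proxy endpoint; on HTTP errors it prints the response and returns None
text_to_speech_polly("Hello, how can I help you?", language="en", save_path="hello_en.mp3")
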
bot_gradio.ipynb ADDED
@@ -0,0 +1,259 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import openai\n",
+     "from vector_db import LanceVectorDb, QnA\n",
+     "from openai_utils import get_embedding, whisper_transcription\n",
+     "from audio_utils import text_to_speech, text_to_speech_polly\n",
+     "import os\n",
+     "\n",
+     "db = LanceVectorDb(\"qna_db\")\n",
+     "OPENAI_KEY = os.environ[\"OPENAI_KEY\"]\n",
+     "openai.api_key = OPENAI_KEY"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "if len(db.table.to_pandas()) == 0:\n",
+     "    print(\"Empty db, trying to load qna's from json file\")\n",
+     "    try:\n",
+     "        db.init_from_qna_json(\"all_questions.json\")\n",
+     "        print(\"Initialized db from json file\")\n",
+     "    except Exception as exception:\n",
+     "        raise Exception(\"Failed to initialize db from json file\") from exception"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import os\n",
+     "\n",
+     "def ensure_dir(directory):\n",
+     "    if not os.path.exists(directory):\n",
+     "        os.makedirs(directory)\n",
+     "\n",
+     "ensure_dir(\"audio\")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "from langdetect import detect\n",
+     "\n",
+     "GENERAL_SYSTEM_PROMPT = (\n",
+     "    \"You are a question answering assistant.\\n\"\n",
+     "    \"You answer questions from users based on information in our database provided as context.\\n\"\n",
+     "    \"You respond in one, maximum two sentences.\\n\"\n",
+     "    \"You use only the information in the context. If the information is not in the context, you tell the user that you don't know.\\n\"\n",
+     "    \"You answer in the language that the question was asked in.\\n\"\n",
+     "    \"You speak german and english.\\n\"\n",
+     ")\n",
+     "\n",
+     "def bot_respond(user_query, chat_messages: list):\n",
+     "    embedding = get_embedding(user_query)\n",
+     "\n",
+     "    qnas = db.get_qna(embedding, lang=\"en\", limit=3)\n",
+     "    print(\"Total_qnas:\", len(qnas), [qna.score for qna in qnas])\n",
+     "    qnas = [qna for qna in qnas if qna.score < 0.45]\n",
+     "    print(\"Filtered_qnas:\", len(qnas))\n",
+     "    context_prompt = f\"The user said: {user_query}\\n\\n\"\n",
+     "\n",
+     "    if len(qnas) > 0:\n",
+     "        example_questions = \"\"\n",
+     "        for qna in qnas:\n",
+     "            example_questions += (\n",
+     "                f\"For question: '{qna.question}'\\n\"\n",
+     "                f\"The answer is: '{qna.answer}'\\n\"\n",
+     "            )\n",
+     "\n",
+     "        context_prompt += (\n",
+     "            f\"Context information from our database:\\n{example_questions}\"\n",
+     "            \"If the user hasn't provided some of the necessary information for answering the question, you can ask the user for it.\\n\"\n",
+     "        )\n",
+     "\n",
+     "    print(context_prompt)\n",
+     "    chat_messages.append({\"role\": \"user\", \"content\": context_prompt})\n",
+     "\n",
+     "    completion = openai.ChatCompletion.create(\n",
+     "        model=\"gpt-3.5-turbo\",\n",
+     "        messages=chat_messages,\n",
+     "    )\n",
+     "    text = completion.choices[0].message.content\n",
+     "\n",
+     "    chat_messages.pop()\n",
+     "    chat_messages.append({\"role\": \"user\", \"content\": user_query})\n",
+     "    chat_messages.append({\"role\": \"system\", \"content\": text})\n",
+     "\n",
+     "    return text, context_prompt\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import random\n",
+     "\n",
+     "def display_history(conversation):\n",
+     "    conversation_string = \"\"\n",
+     "    for message in conversation:\n",
+     "        conversation_string += f\"<<{message['role']}>>:\\n{message['content']}\\n<<{message['role']}>>\\n\\n\"\n",
+     "    return conversation_string\n",
+     "\n",
+     "def handle_audiofile(audio_filepath: str, chat_messages: list):\n",
+     "    user_question = whisper_transcription(audio_filepath)\n",
+     "    print(\"Transcription\", user_question)\n",
+     "\n",
+     "    bot_response_text, context_prompt = bot_respond(user_question, chat_messages)\n",
+     "\n",
+     "    lang = detect(bot_response_text)\n",
+     "    print(\"Detected language:\", lang, \"for text:\", bot_response_text)\n",
+     "\n",
+     "    if lang not in [\"en\", \"de\"]:\n",
+     "        lang = \"en\"\n",
+     "\n",
+     "    output_filepath = os.path.join(\"audio\", f\"output_{random.randint(0, 1000)}.mp3\")\n",
+     "    text_to_speech_polly(bot_response_text, lang, output_filepath)\n",
+     "\n",
+     "    context_prompt += f\"<<tts language>> : {lang}\\n\"\n",
+     "    context_prompt += f\"<<tts text>> : {bot_response_text}\\n\"\n",
+     "\n",
+     "    return user_question, bot_response_text, chat_messages, context_prompt, display_history(chat_messages), output_filepath"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import gradio as gr\n",
+     "\n",
+     "\n",
+     "with gr.Blocks() as demo:\n",
+     "    # initialize the state that will be used to store the chat messages\n",
+     "    chat_messages = gr.State([{\"role\": \"system\", \"content\": GENERAL_SYSTEM_PROMPT}])\n",
+     "\n",
+     "    with gr.Row():\n",
+     "        audio_input = gr.Audio(source=\"microphone\", type=\"filepath\", format=\"mp3\")\n",
+     "        # autoplay=True => run the output audio file automatically\n",
+     "        output_audio = gr.Audio(\n",
+     "            label=\"PhoneBot Answer TTS\", autoplay=True\n",
+     "        )\n",
+     "    with gr.Row():\n",
+     "        user_query_textbox = gr.Textbox(label=\"User Query\")\n",
+     "        assistant_answer = gr.Textbox(label=\"PhoneBot Answer\")\n",
+     "\n",
+     "    with gr.Row():\n",
+     "        context_info = gr.Textbox(label=\"Context provided to the bot + additional infos for debugging\")\n",
+     "        conversation_history = gr.Textbox(label=\"Conversation history\")\n",
+     "\n",
+     "    # when the audio input is stopped, run the transcribe function\n",
+     "    audio_input.stop_recording(\n",
+     "        handle_audiofile,\n",
+     "        inputs=[audio_input, chat_messages],\n",
+     "        outputs=[user_query_textbox, assistant_answer, chat_messages, context_info, conversation_history, output_audio],\n",
+     "    )\n",
+     "\n",
+     "demo.launch(share=True, inbrowser=True, inline=False)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import json\n",
+     "with open(\"all_questions.json\", encoding=\"utf-8\") as f:\n",
+     "    all_questions = json.load(f)[\"qna\"]"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "all_questions = [{\"question\": qna[\"question\"], \"answer\": qna[\"answer\"]} for qna in all_questions]"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 5,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "with open(\"test.json\", \"w\", encoding=\"utf-8\") as f:\n",
+     "    json.dump(all_questions, f, indent=4, ensure_ascii=False)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "!pip install jsonschema"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# exploratory sketch of the OpenAI function-calling API (not wired into app.py)\n",
+     "functions = [\n",
+     "    {\n",
+     "        \"name\": \"get_answer\",\n",
+     "        \"description\": \"Get answer from the bot\",\n",
+     "        \"parameters\": {\"type\": \"object\", \"properties\": {}},\n",
+     "    }\n",
+     "]\n",
+     "\n",
+     "openai.ChatCompletion.create(\n",
+     "    model=\"gpt-3.5-turbo\",\n",
+     "    messages=[{\"role\": \"user\", \"content\": \"Hello\"}],\n",
+     "    functions=functions,\n",
+     ")"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": ".env",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.10.2"
+   },
+   "orig_nbformat": 4
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
openai_utils.py ADDED
@@ -0,0 +1,16 @@
+ import openai
+
+
+ def get_embedding(text, model="text-embedding-ada-002"):
+     text = text.replace("\n", " ")
+     return openai.Embedding.create(input=[text], model=model)["data"][0]["embedding"]
+
+
+ def whisper_transcription(file_path) -> str:
+     with open(file_path, "rb") as audio_file:
+         result = openai.Audio.transcribe(
+             model="whisper-1",
+             file=audio_file,
+         )
+     return result["text"]
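
For context, a minimal sketch of how these two helpers are consumed (mirroring app.py; the key, question, and audio path below are placeholders):

import openai
from openai_utils import get_embedding, whisper_transcription

openai.api_key = "sk-..."  # placeholder; app.py reads this from the OPENAI_KEY env var

# 1536-dimensional ada-002 embedding, used as the query vector for the LanceDB search
vector = get_embedding("What are your opening hours?")
print(len(vector))

# Whisper transcription of a recorded user utterance (hypothetical file)
transcript = whisper_transcription("audio/example_input.mp3")
print(transcript)
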
qna_db/qna_table.lance/_latest.manifest ADDED
Binary file (367 Bytes).
 
qna_db/qna_table.lance/_transactions/0-7be40690-62be-4b03-abe6-253a761dd518.txn ADDED
@@ -0,0 +1 @@
+ $7be40690-62be-4b03-abe6-253a761dd518��uid ���������*string085vector ���������*fixed_size_list:float:153608#question ���������*string08!answer ���������*string08#language ���������*string08#category ���������*string08
qna_db/qna_table.lance/_transactions/1-c6424ea6-7f08-4fa4-95d9-88178d51f44b.txn ADDED
Binary file (99 Bytes).
 
qna_db/qna_table.lance/_versions/1.manifest ADDED
Binary file (310 Bytes).
 
qna_db/qna_table.lance/_versions/2.manifest ADDED
Binary file (367 Bytes).
 
qna_db/qna_table.lance/data/22f8859d-0e07-4014-bbb1-84d05fe3d866.lance ADDED
Binary file (536 kB).
 
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ openai==0.27.9
+ pandas==2.0.3
+ pydantic==2.3.0
+ requests==2.31.0
+ lancedb==0.2.2
+ gradio==3.41.2
+ gTTS==2.3.2
+ langdetect==1.0.9
requirements_frozen.txt ADDED
Binary file (5.17 kB).
 
vector_db.py ADDED
@@ -0,0 +1,95 @@
+ import json
+ from typing import Literal, Optional, Union
+
+ import lancedb
+ import pyarrow as pa
+ from lancedb.pydantic import LanceModel
+
+ qna_schema = pa.schema(
+     [
+         pa.field("uid", pa.string()),
+         pa.field("vector", pa.list_(pa.float32(), 1536)),
+         pa.field("question", pa.string()),
+         pa.field("answer", pa.string()),
+         pa.field("language", pa.string()),
+         pa.field("category", pa.string()),
+     ]
+ )
+
+
+ class QnA(LanceModel):
+     uid: str
+     question: str
+     answer: str
+     language: str
+     category: str
+     score: Optional[float] = None
+
+
+ class LanceVectorDb:
+     def __init__(self, path):
+         self.db = lancedb.connect(path)
+
+         if "qna_table" not in self.db.table_names():
+             self.table = self.db.create_table("qna_table", schema=qna_schema)
+         else:
+             self.table = self.db.open_table("qna_table")
+
+     def init_from_qna_json(self, path):
+         with open(path, encoding="utf-8") as f:
+             qna_data = json.load(f)
+
+         qnas = qna_data["qna"]
+         embeddings = qna_data["embeddings"]
+
+         qnas_with_embeddings = []
+         for qna in qnas:
+             uid = qna["uid"]
+             emb = embeddings.get(uid)
+             if emb is None:
+                 continue
+
+             qna["vector"] = emb
+             qnas_with_embeddings.append(qna)
+
+         self.insert(qnas_with_embeddings)
+
+     def insert(self, data: Union[dict, list[dict]]):
+         if not isinstance(data, list):
+             data = [data]
+         # Temporary workaround: build a pyarrow Table column-wise so the rows
+         # match qna_schema (pending an upstream lancedb fix for dict inserts).
+         columns = list(data[0].keys())
+         data_columns = {column: [d[column] for d in data] for column in columns}
+         elements_to_insert = pa.Table.from_pydict(data_columns, schema=qna_schema)
+         self.table.add(elements_to_insert)
+
+     def get_qna(
+         self,
+         vector: list,
+         lang: Literal["de", "en"] = "de",
+         vector_column: Literal["vector", "answer_vector"] = "vector",
+         metric: Literal["L2", "cosine"] = "L2",
+         limit=3,
+     ):
+         results = (
+             self.table.search(vector, vector_column)
+             .where(f"language == '{lang}'")
+             .metric(metric)
+             .limit(limit)
+             .to_df()
+             .to_dict(orient="records")
+         )
+         results = [QnA(**result, score=result["_distance"]) for result in results]
+         return results
+
+
+ # import json
+ # with open("all_questions.json", encoding="utf-8") as f:
+ #     all_questions = json.load(f)
+
+ # # from vector_db import LanceVectorDb, QnA
+
+ # db = LanceVectorDb("MyDB")
+
+ # db.insert(all_questions)
+ # res = db.get_qna(get_embedding(question), lang="en", limit=5)
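
To show the intended flow end to end, a short hedged sketch of inserting one record and querying it back (the uid, question, and answer are made up for illustration; the vector must be a 1536-float ada-002 embedding to satisfy qna_schema):

from openai_utils import get_embedding
from vector_db import LanceVectorDb

db = LanceVectorDb("qna_db")

# hypothetical record matching qna_schema
record = {
    "uid": "demo-001",
    "vector": get_embedding("What are your opening hours?"),
    "question": "What are your opening hours?",
    "answer": "We are open Monday to Friday from 9:00 to 17:00.",
    "language": "en",
    "category": "general",
}
db.insert(record)

# nearest-neighbour lookup restricted to English entries; a smaller score means a closer match
for qna in db.get_qna(get_embedding("When do you open?"), lang="en", limit=3):
    print(qna.score, qna.question, "->", qna.answer)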