import io
import os
import json
import logging
import secrets
import gradio as gr
import numpy as np
import openai
import pandas as pd
from google.oauth2.service_account import Credentials
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload, MediaFileUpload
from openai.embeddings_utils import distances_from_embeddings
from .gpt_processor import QuestionAnswerer
from .work_flow_controller import WorkFlowController
# OpenAI auth: read once at import time from the environment and installed
# globally on the (legacy, pre-1.0) openai module.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
openai.api_key = OPENAI_API_KEY
class Chatbot:
    """Document Q&A chatbot backed by OpenAI embeddings and Google Drive.

    Uploaded files are processed into a CSV (page embeddings) and JSON
    (summaries) knowledge base; questions are answered by retrieving the
    nearest page by cosine distance and delegating to QuestionAnswerer.
    """

    def __init__(self):
        # Conversation turns as [user_message, bot_message] pairs (gradio style).
        self.history = []
        # "waiting" until a knowledge base has been built, then "done".
        self.upload_state = "waiting"
        # Random per-session id; used to name files staged for upload.
        self.uid = self.__generate_uid()
        # Google Drive v3 client for the shared knowledge-base CSV.
        # NOTE(review): built eagerly at construction — requires the
        # CREDENTIALS env var to be set or this raises.
        self.g_drive_service = self.__init_drive_service()
        # DataFrame of pages + embeddings; None until files are uploaded.
        self.knowledge_base = None
        # Best-matching page (text, page number, file name) for the last query.
        self.context = None
        self.context_page_num = None
        self.context_file_name = None
def build_knowledge_base(self, files, upload_mode="once"):
    """Process *files* into CSV/JSON artifacts and load the knowledge base.

    Runs the project workflow over the uploaded files, records the result
    paths on the instance, then loads the knowledge base either merged
    into the shared Drive database ("Upload to Database") or locally only
    (any other mode, including the default "once").
    """
    controller = WorkFlowController(files, self.uid)
    self.csv_result_path = controller.csv_result_path
    self.json_result_path = controller.json_result_path
    if upload_mode == "Upload to Database":
        self.__get_db_knowledge_base()
        return
    self.__get_local_knowledge_base()
def __get_db_knowledge_base(self):
    """Merge this upload into the shared Drive knowledge base and load it.

    Reads the shared database CSV from Drive, appends the rows produced
    for the current upload, retries the Drive write up to 10 times, and
    keeps the merged table — with embeddings decoded to numpy arrays —
    as this session's knowledge base.

    Fixes vs. original: the unused ``filename`` local is gone; the local
    knowledge base now includes the just-uploaded rows (previously only
    the pre-merge ``db`` was kept) and its embeddings are decoded the
    same way __get_local_knowledge_base decodes them, so retrieval works.
    """
    db = self.__read_db(self.g_drive_service)
    cur_content = pd.read_csv(self.csv_result_path)
    merged = pd.concat([db, cur_content], ignore_index=True)
    for _ in range(10):
        try:
            self.__write_into_db(self.g_drive_service, db, cur_content)
            break
        except Exception as e:
            logging.error(e)
            logging.error("Failed to upload to database, retrying...")
    else:
        # All attempts failed; keep going so the session still works locally.
        logging.error("Giving up on database upload after 10 attempts.")
    # The CSV round-trip serializes embeddings as strings; decode them so
    # __get_index_file can compute distances.
    # NOTE: eval on pipeline-generated CSV content only — not user input.
    merged["page_embedding"] = merged["page_embedding"].apply(eval).apply(np.array)
    self.knowledge_base = merged
    self.upload_state = "done"
def __get_local_knowledge_base(self):
with open(self.csv_result_path, "r", encoding="UTF-8") as fp:
knowledge_base = pd.read_csv(fp)
knowledge_base["page_embedding"] = (
knowledge_base["page_embedding"].apply(eval).apply(np.array)
)
self.knowledge_base = knowledge_base
self.upload_state = "done"
def __write_into_db(self, service, db: pd.DataFrame, cur_content: pd.DataFrame):
    """Append *cur_content* to *db* and push the merged CSV to Drive.

    Overwrites the shared database file (fixed Drive file id) via the
    Drive v3 ``files().update`` call. The merge is local to this method;
    the caller's *db* is not mutated. Propagates googleapiclient errors
    so the caller's retry loop can handle them.

    Fixes vs. original: the ``.execute()`` result is no longer bound to
    an unused name, and the on-disk staging CSV is removed afterwards
    instead of being leaked.
    """
    merged = pd.concat([db, cur_content], ignore_index=True)
    # Drive uploads need an on-disk file; stage it under a per-session name.
    staging_path = f"{self.uid}_knowledge_base.csv"
    merged.to_csv(staging_path, index=False)
    try:
        media = MediaFileUpload(staging_path, resumable=True)
        service.files().update(
            fileId="1m3ozrphHP221hhdCFMFX9-10nzSDfNyW", media_body=media
        ).execute()
    finally:
        # Best-effort cleanup of the staging file.
        try:
            os.remove(staging_path)
        except OSError:
            pass
def __init_drive_service(self):
    """Build an authenticated Google Drive v3 client.

    Credentials come from the CREDENTIALS environment variable, which
    must contain the service-account JSON.

    Returns:
        A googleapiclient Drive v3 service object.

    Raises:
        RuntimeError: if CREDENTIALS is unset or empty (previously this
            surfaced as an opaque TypeError from json.loads(None)).
    """
    SCOPES = ["https://www.googleapis.com/auth/drive"]
    service_account_info = os.getenv("CREDENTIALS")
    if not service_account_info:
        raise RuntimeError(
            "CREDENTIALS environment variable is not set; "
            "cannot initialize the Google Drive service."
        )
    creds = Credentials.from_service_account_info(
        json.loads(service_account_info), scopes=SCOPES
    )
    return build("drive", "v3", credentials=creds)
def __read_db(self, service):
    """Download the shared knowledge-base CSV from Drive as a DataFrame.

    Streams the fixed database file in chunks, printing download
    progress, then parses the buffered bytes with pandas.
    """
    request = service.files().get_media(fileId="1m3ozrphHP221hhdCFMFX9-10nzSDfNyW")
    buffer = io.BytesIO()
    downloader = MediaIoBaseDownload(buffer, request)
    done = False
    while not done:
        status, done = downloader.next_chunk()
        print(f"Download {int(status.progress() * 100)}%.")
    buffer.seek(0)
    return pd.read_csv(buffer)
def __read_file(self, service, filename) -> pd.DataFrame:
    """Download the Drive file named *filename* and parse it as CSV.

    Fixes vs. original: the Drive query hard-coded a placeholder string
    instead of interpolating *filename* (the parameter was silently
    ignored), and a missing file raised a bare IndexError.

    Raises:
        FileNotFoundError: if no Drive file matches *filename*.
    """
    query = f"name='{filename}'"
    results = service.files().list(q=query).execute()
    files = results.get("files", [])
    if not files:
        raise FileNotFoundError(f"No file named {filename!r} found on Drive.")
    # If several files share the name, take the first match (Drive allows
    # duplicate names).
    request = service.files().get_media(fileId=files[0]["id"])
    fh = io.BytesIO()
    downloader = MediaIoBaseDownload(fh, request)
    done = False
    while done is False:
        status, done = downloader.next_chunk()
        print(f"Download {int(status.progress() * 100)}%.")
    fh.seek(0)
    return pd.read_csv(fh)
def __upload_file(self, service):
    """Upload this session's result CSV into the shared Drive folder.

    First prints up to 10 files visible to the service account (debug
    listing preserved from the original), then creates a new Drive file
    named ``ex_bot_database_<uid>.csv`` inside the fixed parent folder.

    Fixes vs. original: the ``.execute()`` result is no longer bound to
    an unused ``request`` variable.
    """
    results = service.files().list(pageSize=10).execute()
    items = results.get("files", [])
    if not items:
        print("No files found.")
    else:
        print("Files:")
        for item in items:
            print(f"{item['name']} ({item['id']})")
    media = MediaFileUpload(self.csv_result_path, resumable=True)
    filename_prefix = "ex_bot_database_"
    filename = filename_prefix + self.uid + ".csv"
    service.files().create(
        media_body=media,
        body={
            "name": filename,
            # Destination folder id on the shared Drive.
            "parents": [
                "1Lp21EZlVlqL-c27VQBC6wTbUC1YpKMsG"
            ],
        },
    ).execute()
def clear_state(self):
self.context = None
self.context_page_num = None
self.context_file_name = None
self.knowledge_base = None
self.upload_state = "waiting"
self.history = []
def send_system_notification(self):
if self.upload_state == "waiting":
conversation = [["已上傳文件", "文件處理中(摘要、翻譯等),結束後將自動回覆"]]
return conversation
elif self.upload_state == "done":
conversation = [["已上傳文件", "文件處理完成,請開始提問"]]
return conversation
def change_md(self):
    """Refresh the summary Markdown component and make it visible."""
    return gr.Markdown.update(self.__construct_summary(), visible=True)
def __construct_summary(self):
    """Concatenate a Markdown summary section for every processed file.

    Reads the workflow's JSON result (one entry per file, each with
    ``file_name``, ``total_pages`` and ``summarized_content`` fields)
    and renders one "### 文件摘要" snippet per file.
    """
    with open(self.json_result_path, "r", encoding="UTF-8") as fp:
        knowledge_base = json.load(fp)
    context = ""
    for key in knowledge_base.keys():
        file_name = knowledge_base[key]["file_name"]
        total_page = knowledge_base[key]["total_pages"]
        summary = knowledge_base[key]["summarized_content"]
        # The template's exact whitespace is part of the rendered output.
        file_context = f"""
### 文件摘要
{file_name} (共 {total_page} 頁)
{summary}
"""
        context += file_context
    return context
def user(self, message):
self.history += [[message, None]]
return "", self.history
def bot(self):
user_message = self.history[-1][0]
print(f"user_message: {user_message}")
if self.knowledge_base is None:
response = [
[user_message, "請先上傳文件"],
]
self.history = response
return self.history
else:
self.__get_index_file(user_message)
if self.context is None:
response = [
[user_message, "無法找到相關文件,請重新提問"],
]
self.history = response
return self.history
else:
qa_processor = QuestionAnswerer()
bot_message = qa_processor.answer_question(
self.context,
self.context_page_num,
self.context_file_name,
self.history,
)
print(f"bot_message: {bot_message}")
response = [
[user_message, bot_message],
]
self.history[-1] = response[0]
return self.history
def __get_index_file(self, user_message):
    """Find the knowledge-base page most relevant to *user_message*.

    Embeds the query with text-embedding-ada-002 via the legacy
    (pre-1.0) openai SDK, computes cosine distances against every
    page embedding, and sorts the knowledge base by distance in place.

    Side effects: sets self.context / self.context_page_num /
    self.context_file_name to the nearest page, or self.context = None
    when the best distance exceeds 0.2 (relevance cutoff — note that
    page_num / file_name are left stale in that case; bot() only checks
    self.context, so this is currently harmless).
    """
    user_message_embedding = openai.Embedding.create(
        input=user_message, engine="text-embedding-ada-002"
    )["data"][0]["embedding"]
    # distances_from_embeddings comes from openai.embeddings_utils,
    # which was removed in openai>=1.0 — keep the old SDK pinned or
    # vendor the helper if upgrading.
    self.knowledge_base["distance"] = distances_from_embeddings(
        user_message_embedding,
        self.knowledge_base["page_embedding"].values,
        distance_metric="cosine",
    )
    # Persistently reorders the DataFrame; row 0 is the nearest page.
    self.knowledge_base = self.knowledge_base.sort_values(
        by="distance", ascending=True
    )
    if self.knowledge_base["distance"].values[0] > 0.2:
        self.context = None
    else:
        self.context = self.knowledge_base["page_content"].values[0]
        self.context_page_num = self.knowledge_base["page_num"].values[0]
        self.context_file_name = self.knowledge_base["file_name"].values[0]
def __generate_uid(self):
return secrets.token_hex(8)