import io
import os
import json
import logging
import secrets

import gradio as gr
import numpy as np
import openai
import pandas as pd
from google.oauth2.service_account import Credentials
from googleapiclient.discovery import build
from googleapiclient.http import MediaIoBaseDownload, MediaFileUpload

# Note: embeddings_utils ships only with the pre-1.0 openai package
# (and its "openai[embeddings]" extra); it was removed in openai 1.x.
from openai.embeddings_utils import distances_from_embeddings

from .gpt_processor import QuestionAnswerer
from .work_flow_controller import WorkFlowController

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
openai.api_key = OPENAI_API_KEY


class Chatbot:
    def __init__(self):
        self.history = []
        self.upload_state = "waiting"
        self.uid = self.__generate_uid()
        self.g_drive_service = self.__init_drive_service()

        self.knowledge_base = None
        self.context = None
        self.context_page_num = None
        self.context_file_name = None

    def build_knowledge_base(self, files, upload_mode="once"):
        """Process the uploaded files and load the resulting knowledge base."""
        work_flow_controller = WorkFlowController(files, self.uid)
        self.csv_result_path = work_flow_controller.csv_result_path
        self.json_result_path = work_flow_controller.json_result_path

        if upload_mode == "Upload to Database":
            self.__get_db_knowledge_base()
        else:
            self.__get_local_knowledge_base()

    def __get_db_knowledge_base(self):
        db = self.__read_db(self.g_drive_service)
        cur_content = pd.read_csv(self.csv_result_path)
        # Drive uploads fail intermittently, so retry a bounded number of times.
        for _ in range(10):
            try:
                self.__write_into_db(self.g_drive_service, db, cur_content)
                break
            except Exception as e:
                logging.error(e)
                logging.error("Failed to upload to database, retrying...")
        # Keep the merged table in memory and parse the stringified embeddings
        # back into numpy arrays so retrieval can compute distances on them.
        db = pd.concat([db, cur_content], ignore_index=True)
        db["page_embedding"] = db["page_embedding"].apply(eval).apply(np.array)
        self.knowledge_base = db
        self.upload_state = "done"

    def __get_local_knowledge_base(self):
        with open(self.csv_result_path, "r", encoding="UTF-8") as fp:
            knowledge_base = pd.read_csv(fp)
        # Embeddings are serialized as stringified lists in the CSV; eval()
        # turns them back into lists (ast.literal_eval would be a safer parser).
        knowledge_base["page_embedding"] = (
            knowledge_base["page_embedding"].apply(eval).apply(np.array)
        )
        self.knowledge_base = knowledge_base
        self.upload_state = "done"

    def __write_into_db(self, service, db: pd.DataFrame, cur_content: pd.DataFrame):
        db = pd.concat([db, cur_content], ignore_index=True)
        db.to_csv(f"{self.uid}_knowledge_base.csv", index=False)
        media = MediaFileUpload(f"{self.uid}_knowledge_base.csv", resumable=True)
        service.files().update(
            fileId="1m3ozrphHP221hhdCFMFX9-10nzSDfNyW", media_body=media
        ).execute()

    def __init_drive_service(self):
        SCOPES = ["https://www.googleapis.com/auth/drive"]
        # The full service-account JSON is expected in the CREDENTIALS
        # environment variable.
        SERVICE_ACCOUNT_INFO = os.getenv("CREDENTIALS")
        service_account_info_dict = json.loads(SERVICE_ACCOUNT_INFO)
        creds = Credentials.from_service_account_info(
            service_account_info_dict, scopes=SCOPES
        )
        return build("drive", "v3", credentials=creds)

    def __read_db(self, service):
        request = service.files().get_media(fileId="1m3ozrphHP221hhdCFMFX9-10nzSDfNyW")
        fh = io.BytesIO()
        downloader = MediaIoBaseDownload(fh, request)
        done = False
        while not done:
            status, done = downloader.next_chunk()
            print(f"Download {int(status.progress() * 100)}%.")
        fh.seek(0)
        return pd.read_csv(fh)

    def __read_file(self, service, filename) -> pd.DataFrame:
        # Look the file up by name, then download it and parse it as CSV.
        query = f"name='{filename}'"
        results = service.files().list(q=query).execute()
        files = results.get("files", [])
        file_id = files[0]["id"]

        request = service.files().get_media(fileId=file_id)
        fh = io.BytesIO()
        downloader = MediaIoBaseDownload(fh, request)
        done = False
        while not done:
            status, done = downloader.next_chunk()
            print(f"Download {int(status.progress() * 100)}%.")
        fh.seek(0)
        return pd.read_csv(fh)

    def __upload_file(self, service):
        results = service.files().list(pageSize=10).execute()
        items = results.get("files", [])
        if not items:
            print("No files found.")
        else:
            print("Files:")
            for item in items:
                print(f"{item['name']} ({item['id']})")

        media = MediaFileUpload(self.csv_result_path, resumable=True)
        filename_prefix = "ex_bot_database_"
        filename = filename_prefix + self.uid + ".csv"
        service.files().create(
            media_body=media,
            body={
                "name": filename,
                "parents": ["1Lp21EZlVlqL-c27VQBC6wTbUC1YpKMsG"],
            },
        ).execute()

    def clear_state(self):
        self.context = None
        self.context_page_num = None
        self.context_file_name = None
        self.knowledge_base = None
        self.upload_state = "waiting"
        self.history = []

    def send_system_notification(self):
        if self.upload_state == "waiting":
            # "Document uploaded" / "Processing the document (summary,
            # translation, ...); a reply will follow automatically when done."
            conversation = [["已上傳文件", "文件處理中(摘要、翻譯等),結束後將自動回覆"]]
            return conversation
        elif self.upload_state == "done":
            # "Document uploaded" / "Processing finished; please start asking
            # questions."
            conversation = [["已上傳文件", "文件處理完成,請開始提問"]]
            return conversation

    def change_md(self):
        content = self.__construct_summary()
        return gr.Markdown.update(content, visible=True)

    def __construct_summary(self):
        with open(self.json_result_path, "r", encoding="UTF-8") as fp:
            knowledge_base = json.load(fp)

        context = ""
        for key in knowledge_base.keys():
            file_name = knowledge_base[key]["file_name"]
            total_page = knowledge_base[key]["total_pages"]
            summary = knowledge_base[key]["summarized_content"]
            # Markdown block: "### Document summary <file_name>
            # (<total_page> pages in total)" followed by the summary text.
            file_context = f"""
### 文件摘要 {file_name} (共 {total_page} 頁)

{summary}

"""
            context += file_context
        return context

    def user(self, message):
        # Gradio handler: record the user turn; the reply slot is filled in by
        # bot() on the next step of the event chain.
        self.history += [[message, None]]
        return "", self.history

    def bot(self):
        user_message = self.history[-1][0]
        print(f"user_message: {user_message}")

        if self.knowledge_base is None:
            # "Please upload a document first."
            self.history[-1] = [user_message, "請先上傳文件"]
            return self.history

        self.__get_index_file(user_message)
        if self.context is None:
            # "No relevant document found; please rephrase the question."
            self.history[-1] = [user_message, "無法找到相關文件,請重新提問"]
            return self.history

        qa_processor = QuestionAnswerer()
        bot_message = qa_processor.answer_question(
            self.context,
            self.context_page_num,
            self.context_file_name,
            self.history,
        )
        print(f"bot_message: {bot_message}")
        self.history[-1] = [user_message, bot_message]
        return self.history

    def __get_index_file(self, user_message):
        user_message_embedding = openai.Embedding.create(
            input=user_message, engine="text-embedding-ada-002"
        )["data"][0]["embedding"]

        # Rank every page by cosine distance to the question embedding.
        self.knowledge_base["distance"] = distances_from_embeddings(
            user_message_embedding,
            self.knowledge_base["page_embedding"].values,
            distance_metric="cosine",
        )
        self.knowledge_base = self.knowledge_base.sort_values(
            by="distance", ascending=True
        )

        # Treat a best distance above 0.2 as "no relevant page found".
        if self.knowledge_base["distance"].values[0] > 0.2:
            self.context = None
        else:
            self.context = self.knowledge_base["page_content"].values[0]
            self.context_page_num = self.knowledge_base["page_num"].values[0]
            self.context_file_name = self.knowledge_base["file_name"].values[0]

    def __generate_uid(self):
        return secrets.token_hex(8)