Spaces (Runtime error)
ChenyuRabbitLove committed
Commit c88c1d9 • Parent(s): 1beaddf
feat/add g-drive coontection

Files changed:
- app.py (+9 -4)
- utils/chatbot.py (+112 -7)
- utils/work_flow_controller.py (+12 -10)
app.py
CHANGED
@@ -30,7 +30,11 @@ with gr.Blocks() as demo:
 
     with gr.Row():
         index_file = gr.File(
-            file_count="multiple", file_types=["pdf"], label="Upload PDF file"
+            file_count="multiple", file_types=["pdf"], label="Upload PDF file", scale=3
+        )
+        upload_to_db = gr.CheckboxGroup(
+            ["Upload to Database"],
+            label="是否上傳至資料庫", info="將資料上傳至資料庫時,資料庫會自動建立索引,下次使用時可以直接檢索,預設為僅作這次使用", scale=1
         )
 
     with gr.Row():
@@ -42,7 +46,8 @@ with gr.Blocks() as demo:
     3. 可以根據下方的摘要內容來提問
     4. 每次對話會根據第一個問題的內容來檢索所有文件,並挑選最能回答問題的文件來回覆
     5. 要切換檢索的文件,請點選「清除」按鈕後再重新提問
-
+
+            """,
     )
 
     with gr.Row():
@@ -80,6 +85,7 @@ with gr.Blocks() as demo:
             **bot_args
         ).then(lambda: gr.update(interactive=True), None, [user_input], queue=False)
 
+
     # defining workflow of clear state
     clear_state_args = dict(
         fn=clear_state,
@@ -98,7 +104,7 @@ with gr.Blocks() as demo:
 
     bulid_knowledge_base_args = dict(
         fn=build_knowledge_base,
-        inputs=[user_chatbot, index_file],
+        inputs=[user_chatbot, index_file, upload_to_db],
         outputs=None,
     )
 
@@ -118,6 +124,5 @@ with gr.Blocks() as demo:
 
     video_text_input.submit(video_bot, [test_video_chabot, video_text_input], video_text_output, api_name="video_bot")
 
-
 if __name__ == "__main__":
     demo.launch()
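For reference, here is a minimal sketch (not code from this commit; the handler and component names are placeholders) of how a gr.CheckboxGroup value reaches a click handler: Gradio passes the list of currently selected labels, so the function sees [] when the box is unchecked and ["Upload to Database"] when it is checked. Note that build_knowledge_base in utils/chatbot.py compares upload_mode against the single string "上傳至資料庫" ("upload to database"), so the list value wired in above would always fall through to the local branch; the sketch tests membership instead.

import gradio as gr

def build_kb(files, upload_to_db):
    # upload_to_db arrives as the list of selected labels (assumed
    # CheckboxGroup semantics), so membership is the natural test.
    mode = "database" if "Upload to Database" in upload_to_db else "local-only"
    return f"received {len(files or [])} file(s), mode={mode}"

with gr.Blocks() as demo:
    index_file = gr.File(file_count="multiple", file_types=["pdf"])
    upload_to_db = gr.CheckboxGroup(["Upload to Database"])
    status = gr.Textbox(label="status")
    gr.Button("Build").click(build_kb, [index_file, upload_to_db], status)

if __name__ == "__main__":
    demo.launch()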
utils/chatbot.py
CHANGED
@@ -1,31 +1,62 @@
-import
+import io
 import os
+import json
+import logging
+import secrets
 
+import gradio as gr
+import numpy as np
 import openai
 import pandas as pd
-
-
+from google.oauth2.service_account import Credentials
+from googleapiclient.discovery import build
+from googleapiclient.http import MediaIoBaseDownload, MediaFileUpload
 from openai.embeddings_utils import distances_from_embeddings
 
-from .work_flow_controller import WorkFlowController
 from .gpt_processor import QuestionAnswerer
+from .work_flow_controller import WorkFlowController
 
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+openai.api_key = OPENAI_API_KEY
 
 class Chatbot:
     def __init__(self) -> None:
         self.history = []
         self.upload_state = "waiting"
+        self.uid = self.__generate_uid()
 
+        self.g_drive_service = self.__init_drive_service()
         self.knowledge_base = None
         self.context = None
         self.context_page_num = None
         self.context_file_name = None
 
-    def build_knowledge_base(self, files):
-        work_flow_controller = WorkFlowController(files)
+    def build_knowledge_base(self, files, upload_mode="僅作這次使用"):
+        work_flow_controller = WorkFlowController(files, self.uid)
         self.csv_result_path = work_flow_controller.csv_result_path
         self.json_result_path = work_flow_controller.json_result_path
 
+        if upload_mode == "上傳至資料庫":
+            self.knowledge_base = self.__get_db_knowledge_base()
+        else:
+            self.knowledge_base = self.__get_local_knowledge_base()
+
+    def __get_db_knowledge_base(self):
+        filename = "knowledge_base.csv"
+        db = self.__read_db(self.g_drive_service)
+        cur_content = pd.read_csv(self.csv_result_path)
+        for _ in range(10):
+            try:
+                self.__write_into_db(self.g_drive_service, db, cur_content)
+                break
+            except Exception as e:
+                logging.error(e)
+                logging.error("Failed to upload to database, retrying...")
+                continue
+        self.knowledge_base = db
+        self.upload_state = "done"
+
+    def __get_local_knowledge_base(self):
         with open(self.csv_result_path, "r", encoding="UTF-8") as fp:
             knowledge_base = pd.read_csv(fp)
             knowledge_base["page_embedding"] = (
@@ -35,10 +66,81 @@ class Chatbot:
         self.knowledge_base = knowledge_base
         self.upload_state = "done"
 
+    def __write_into_db(self, service, db: pd.DataFrame, cur_content: pd.DataFrame):
+        # db = pd.concat([db, cur_content], ignore_index=True)
+        # db.to_csv(f"{self.uid}_knowledge_base.csv", index=False)
+        cur_content.to_csv(f"{self.uid}_knowledge_base.csv", index=False)
+        media = MediaFileUpload(f"{self.uid}_knowledge_base.csv", resumable=True)
+        request = service.files().update(fileId="1m3ozrphHP221hhdCFMFX9-10nzSDfNyW", media_body=media).execute()
+
+    def __init_drive_service(self):
+        SCOPES = ['https://www.googleapis.com/auth/drive']
+        SERVICE_ACCOUNT_FILE = os.getenv("CREDENTIALS")
+
+        creds = Credentials.from_service_account_file(SERVICE_ACCOUNT_FILE, scopes=SCOPES)
+
+        return build('drive', 'v3', credentials=creds)
+
+    def __read_db(self, service):
+        request = service.files().get_media(fileId="1m3ozrphHP221hhdCFMFX9-10nzSDfNyW")
+        fh = io.BytesIO()
+        downloader = MediaIoBaseDownload(fh, request)
+
+        done = False
+        while done is False:
+            status, done = downloader.next_chunk()
+            print(f"Download {int(status.progress() * 100)}%.")
+
+        # file_content = fh.getvalue().decode('utf-8')
+        fh.seek(0)
+
+        return pd.read_csv(fh)
+
+    def __read_file(self, service, filename) -> pd.DataFrame:
+        query = f"name='{filename}'"
+        results = service.files().list(q=query).execute()
+        files = results.get('files', [])
+
+        file_id = files[0]['id']
+
+        request = service.files().get_media(fileId=file_id)
+        fh = io.BytesIO()
+        downloader = MediaIoBaseDownload(fh, request)
+
+        done = False
+        while done is False:
+            status, done = downloader.next_chunk()
+            print(f"Download {int(status.progress() * 100)}%.")
+
+        # file_content = fh.getvalue().decode('utf-8')
+        fh.seek(0)
+
+        return pd.read_csv(fh)
+
+    def __upload_file(self, service):
+        results = service.files().list(pageSize=10).execute()
+        items = results.get('files', [])
+        if not items:
+            print('No files found.')
+        else:
+            print('Files:')
+            for item in items:
+                print(f"{item['name']} ({item['id']})")
+
+        media = MediaFileUpload(self.csv_result_path, resumable=True)
+        filename_prefix = 'ex_bot_database_'
+        filename = filename_prefix + self.uid + '.csv'
+        request = service.files().create(media_body=media, body={
+            'name': filename,
+            'parents': ["1Lp21EZlVlqL-c27VQBC6wTbUC1YpKMsG"]  # Optional, to place the file in a specific folder
+        }).execute()
+
+
     def clear_state(self):
         self.context = None
         self.context_page_num = None
         self.context_file_name = None
+        self.knowledge_base = None
         self.upload_state = "waiting"
         self.history = []
 
@@ -130,9 +232,12 @@ class Chatbot:
         self.context_page_num = self.knowledge_base["page_num"].values[0]
         self.context_file_name = self.knowledge_base["file_name"].values[0]
 
+    def __generate_uid(self):
+        return secrets.token_hex(8)
+
+
 class VideoChatbot:
     def __init__(self) -> None:
-        openai.api_key = os.getenv("OPENAI_API_KEY")
         self.metadata_keys = ["標題", "逐字稿", "摘要", "關鍵字"]
         self.metadata = {
             "c2fK-hxnPSY":{
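The new private helpers follow standard google-api-python-client patterns; below is a self-contained sketch of the same round trip, with FILE_ID standing in for the Drive file id hard-coded in the commit and CREDENTIALS assumed to name a service-account JSON key file, as in __init_drive_service. One design note: because __write_into_db keeps the pd.concat merge commented out, each call overwrites the shared Drive CSV with only the current session's rows rather than appending to it.

import io
import os

import pandas as pd
from google.oauth2.service_account import Credentials
from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload, MediaIoBaseDownload

SCOPES = ["https://www.googleapis.com/auth/drive"]
FILE_ID = "YOUR_DRIVE_FILE_ID"  # placeholder for the id hard-coded in the diff

def drive_service():
    # CREDENTIALS is assumed to point at a service-account JSON key file
    creds = Credentials.from_service_account_file(os.getenv("CREDENTIALS"), scopes=SCOPES)
    return build("drive", "v3", credentials=creds)

def download_csv(service, file_id):
    # files().get_media streams the raw file content; MediaIoBaseDownload
    # pulls it chunk by chunk into an in-memory buffer
    request = service.files().get_media(fileId=file_id)
    buf = io.BytesIO()
    downloader = MediaIoBaseDownload(buf, request)
    done = False
    while not done:
        _, done = downloader.next_chunk()
    buf.seek(0)
    return pd.read_csv(buf)

def overwrite_csv(service, file_id, local_path):
    # files().update replaces the content of an existing Drive file in place
    media = MediaFileUpload(local_path, resumable=True)
    service.files().update(fileId=file_id, media_body=media).execute()

if __name__ == "__main__":
    svc = drive_service()
    print(download_csv(svc, FILE_ID).head())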
utils/work_flow_controller.py
CHANGED
@@ -20,10 +20,11 @@ processors = {
 
 
 class WorkFlowController:
-    def __init__(self, file_src) -> None:
+    def __init__(self, file_src, uid) -> None:
         # check if the file_path is list
         # self.file_paths = self.__get_file_name(file_src)
         self.file_paths = [x.name for x in file_src]
+        self.uid = uid
 
         print(self.file_paths)
 
@@ -83,6 +84,7 @@ class WorkFlowController:
 
         for i, _ in enumerate(file["file_content"]):
             # use i+1 to meet the index of file_content
+            print("Translating page: " + str(i + 1))
             file["file_content"][i + 1][
                 "page_content"
             ] = translator.translate_to_chinese(
@@ -97,33 +99,34 @@ class WorkFlowController:
         # process file content
         # return processed data
         if not file["is_chinese"]:
+            print("Translating to chinese...")
            file = self.__translate_to_chinese(file)
+        print("Getting embedding...")
        file = self.__get_embedding(file)
+        print("Getting summary...")
        file = self.__get_summary(file)
        return file
 
     def __dump_to_json(self):
         with open(
-            os.path.join(os.getcwd(), "
+            os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.json"), "w", encoding="utf-8"
         ) as f:
             print(
                 "Dumping to json, the path is: "
-                + os.path.join(os.getcwd(), "
+                + os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.json")
             )
-        self.json_result_path = os.path.join(os.getcwd(), "
+        self.json_result_path = os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.json")
         json.dump(self.files_info, f, indent=4, ensure_ascii=False)
 
     def __construct_knowledge_base_dataframe(self):
         rows = []
         for file_path, content in self.files_info.items():
-            file_full_content = content["file_full_content"]
             for page_num, page_details in content["file_content"].items():
                 row = {
                     "file_name": content["file_name"],
                     "page_num": page_details["page_num"],
                     "page_content": page_details["page_content"],
                     "page_embedding": page_details["page_embedding"],
-                    "file_full_content": file_full_content,
                 }
                 rows.append(row)
 
@@ -132,19 +135,18 @@ class WorkFlowController:
             "page_num",
             "page_content",
             "page_embedding",
-            "file_full_content",
         ]
         df = pd.DataFrame(rows, columns=columns)
         return df
 
     def __dump_to_csv(self):
         df = self.__construct_knowledge_base_dataframe()
-        df.to_csv(os.path.join(os.getcwd(), "
+        df.to_csv(os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.csv"), index=False)
         print(
             "Dumping to csv, the path is: "
-            + os.path.join(os.getcwd(), "
+            + os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.csv")
         )
-        self.csv_result_path = os.path.join(os.getcwd(), "
+        self.csv_result_path = os.path.join(os.getcwd(), f"{self.uid}_knowledge_base.csv")
 
     def __get_file_name(self, file_src):
         file_paths = [x.name for x in file_src]
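Finally, a small sketch of the per-session naming scheme the uid introduces: Chatbot.__generate_uid draws 8 random bytes (16 hex characters) via secrets.token_hex, and WorkFlowController writes its JSON/CSV artifacts under that prefix, so concurrent sessions stop colliding on a single shared knowledge_base file. session_paths is a hypothetical helper for illustration, not part of the repo.

import os
import secrets
from typing import Dict, Optional

def session_paths(uid: Optional[str] = None) -> Dict[str, str]:
    # 8 random bytes -> 16 hex characters, matching Chatbot.__generate_uid
    uid = uid or secrets.token_hex(8)
    cwd = os.getcwd()
    return {
        "uid": uid,
        "json": os.path.join(cwd, f"{uid}_knowledge_base.json"),
        "csv": os.path.join(cwd, f"{uid}_knowledge_base.csv"),
    }

if __name__ == "__main__":
    print(session_paths())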