import json
import time
import random
import os
import openai
import gradio as gr
import pandas as pd
import numpy as np
from openai.embeddings_utils import distances_from_embeddings
from utils.gpt_processor import QuestionAnswerer
from utils.work_flow_controller import WorkFlowController
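
# Note: openai.Embedding.create and openai.embeddings_utils come from the pre-1.0
# OpenAI Python SDK (openai<1.0). The API key is assumed to be supplied via the
# OPENAI_API_KEY environment variable, which that SDK reads automatically.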

qa_processor = QuestionAnswerer()

# Module-level state shared across the Gradio callbacks below.
CSV_FILE_PATHS = ''
JSON_FILE_PATHS = ''
KNOWLEDGE_BASE = None
CONTEXT = None
CONTEXT_PAGE_NUM = None
CONTEXT_FILE_NAME = None


def build_knowledge_base(files):
    global CSV_FILE_PATHS
    global JSON_FILE_PATHS
    global KNOWLEDGE_BASE

    work_flow_controller = WorkFlowController(files)
    CSV_FILE_PATHS = work_flow_controller.csv_result_path
    JSON_FILE_PATHS = work_flow_controller.result_path

    with open(CSV_FILE_PATHS, 'r', encoding='UTF-8') as fp:
        knowledge_base = pd.read_csv(fp)
        knowledge_base['page_embedding'] = knowledge_base['page_embedding'].apply(eval).apply(np.array)

    KNOWLEDGE_BASE = knowledge_base
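
# Assumed CSV layout (produced by WorkFlowController, not shown here): one row per
# PDF page with at least the columns 'file_name', 'page_num', 'page_content' and
# 'page_embedding', where 'page_embedding' holds a list serialized as a string,
# e.g. "[0.0123, -0.0456, ...]" -- hence the eval -> np.array round-trip above.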


def construct_summary():
    with open(JSON_FILE_PATHS, 'r', encoding='UTF-8') as fp:
        knowledge_base = json.load(fp)

    context = ''
    for key in knowledge_base.keys():
        file_name = knowledge_base[key]['file_name']
        total_page = knowledge_base[key]['total_pages']
        summary = knowledge_base[key]['summarized_content']
        file_context = f"""
### Document summary
{file_name} ({total_page} pages in total)<br><br>
{summary}<br><br>
"""
        context += file_context
    return context


def change_md():
    content = construct_summary()
    return gr.Markdown.update(content, visible=True)


def user(message, history):
    return "", history + [[message, None]]


def system_notification(action):
    if action == 'upload':
        return [['Document uploaded', 'Processing the document (summarization, translation, etc.); a reply will follow automatically when it finishes']]
    else:
        return [['Document uploaded', 'Document processing finished; you can start asking questions']]


def get_index_file(user_message):
    global KNOWLEDGE_BASE
    global CONTEXT
    global CONTEXT_PAGE_NUM
    global CONTEXT_FILE_NAME

    user_message_embedding = openai.Embedding.create(
        input=user_message, engine='text-embedding-ada-002')['data'][0]['embedding']

    KNOWLEDGE_BASE['distance'] = distances_from_embeddings(
        user_message_embedding, KNOWLEDGE_BASE['page_embedding'].values, distance_metric='cosine')
    # Keep the full knowledge base intact and only take the closest page as a
    # candidate, so later questions can still search every page.
    best_match = KNOWLEDGE_BASE.sort_values(by='distance', ascending=True).head(1)

    if best_match['distance'].values[0] > 0.2:
        CONTEXT = None
    else:
        CONTEXT = best_match['page_content'].values[0]
        CONTEXT_PAGE_NUM = best_match['page_num'].values[0]
        CONTEXT_FILE_NAME = best_match['file_name'].values[0]
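
# Retrieval sketch: the question is embedded with text-embedding-ada-002 and compared
# against every page embedding by cosine distance (via
# openai.embeddings_utils.distances_from_embeddings). Only the single closest page is
# used as context, and a distance above 0.2 (roughly cosine similarity below 0.8) is
# treated as "no relevant document". The 0.2 cut-off is a tunable heuristic, not a
# value dictated by the embedding model.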


def bot(history):
    user_message = history[-1][0]
    global CONTEXT
    print(f'user_message: {user_message}')

    if KNOWLEDGE_BASE is None:
        response = [
            [user_message, "Please upload a document first"],
        ]
        history = response
        return history
    elif CONTEXT is None:
        get_index_file(user_message)
        print(f'CONTEXT: {CONTEXT}')
        if CONTEXT is None:
            response = [
                [user_message, "No relevant document found; please ask your question again"],
            ]
            history = response
            return history

    if CONTEXT is not None:
        bot_message = qa_processor.answer_question(CONTEXT, CONTEXT_PAGE_NUM, CONTEXT_FILE_NAME, history)
        print(f'bot_message: {bot_message}')
        response = [
            [user_message, bot_message],
        ]
        history[-1] = response[0]
        return history
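
# Conversation flow: retrieval runs only for the first question after an upload (or
# after "Clear"); the matched page is then cached in CONTEXT and reused for every
# follow-up question until clear_state() resets it.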


def clear_state():
    global CONTEXT
    global CONTEXT_PAGE_NUM
    global CONTEXT_FILE_NAME
    CONTEXT = None
    CONTEXT_PAGE_NUM = None
    CONTEXT_FILE_NAME = None


with gr.Blocks() as demo:
    history = gr.State([])
    upload_state = gr.State("upload")
    finished = gr.State("finished")
    user_question = gr.State("")

    with gr.Row():
        gr.HTML('Junyi Academy Chatbot')
        # status_display = gr.Markdown("Success", elem_id="status_display")

    with gr.Row(equal_height=True):
        with gr.Column(scale=5):
            with gr.Row():
                chatbot = gr.Chatbot()
            with gr.Row():
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text",
                        container=False,
                    )
                # with gr.Column(min_width=70, scale=1):
                #     submit_btn = gr.Button("Send")
                with gr.Column(min_width=70, scale=1):
                    clear_btn = gr.Button("Clear")
                with gr.Column(min_width=70, scale=1):
                    submit_btn = gr.Button("Send")

    response = user_input.submit(
        user,
        [user_input, chatbot],
        [user_input, chatbot],
        queue=False,
    ).then(bot, chatbot, chatbot)
    response.then(lambda: gr.update(interactive=True), None, [user_input], queue=False)

    clear_btn.click(lambda: None, None, chatbot, queue=False)
    submit_btn.click(
        user,
        [user_input, chatbot],
        [user_input, chatbot],
        queue=False,
    ).then(bot, chatbot, chatbot).then(lambda: gr.update(interactive=True), None, [user_input], queue=False)
    clear_btn.click(clear_state, None, None, queue=False)

    with gr.Row():
        index_file = gr.File(file_count="multiple", file_types=["pdf"], label="Upload PDF file")

    with gr.Row():
        instruction = gr.Markdown("""
        ## Instructions
        1. Upload one or more PDF files; the system automatically summarizes and translates them to build a knowledge base
        2. Type a question in the input box above and the system will answer automatically
        3. You can base your questions on the summaries shown below
        4. Each conversation retrieves across all documents using the first question and answers from the document that matches it best
        5. To switch to a different document, click the "Clear" button and then ask your question again
        """)

    with gr.Row():
        describe = gr.Markdown('', visible=True)

    index_file.upload(system_notification, [upload_state], chatbot) \
        .then(lambda: gr.update(interactive=True), None, None, queue=False) \
        .then(build_knowledge_base, [index_file]) \
        .then(system_notification, [finished], chatbot) \
        .then(lambda: gr.update(interactive=True), None, None, queue=False) \
        .then(change_md, None, describe)


if __name__ == "__main__":
    demo.launch()
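
# Running locally (a sketch, assuming the utils package that provides QuestionAnswerer
# and WorkFlowController sits next to this file): set OPENAI_API_KEY in the
# environment and start the app with `python app.py`.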