File size: 6,985 Bytes
f807e7d
 
 
a2f42ca
f807e7d
a2f42ca
175c5c3
f807e7d
a2f42ca
 
f807e7d
 
a2f42ca
f807e7d
 
a2f42ca
 
 
 
 
 
f807e7d
a2f42ca
 
 
 
f807e7d
a2f42ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f807e7d
 
 
 
a2f42ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f807e7d
 
a2f42ca
 
 
 
 
 
 
f807e7d
 
a2f42ca
 
 
 
f807e7d
a2f42ca
f807e7d
a2f42ca
 
 
 
 
 
 
 
f807e7d
 
 
 
 
a2f42ca
 
 
 
 
 
 
 
 
f807e7d
 
 
a2f42ca
 
f807e7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2f42ca
 
 
f807e7d
 
 
 
 
 
 
a2f42ca
f807e7d
175c5c3
a2f42ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175c5c3
a2f42ca
 
 
 
 
 
 
f807e7d
 
ea129da
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import ast
import json
import os
import random
import time

import gradio as gr
import numpy as np
import openai
import pandas as pd
from openai.embeddings_utils import distances_from_embeddings

from utils.gpt_processor import QuestionAnswerer
from utils.work_flow_controller import WorkFlowController

qa_processor = QuestionAnswerer()  # shared GPT question-answering helper used by bot()
# Module-level state shared between the Gradio callbacks below.
CSV_FILE_PATHS = ''  # path to the embeddings CSV produced by WorkFlowController
JSON_FILE_PATHS = ''  # path to the summary JSON produced by WorkFlowController
KNOWLEDGE_BASE = None  # pandas DataFrame of pages + embeddings (set in build_knowledge_base)
CONTEXT = None  # page content of the currently selected document page
CONTEXT_PAGE_NUM = None  # page number of the selected page
CONTEXT_FILE_NAME = None  # file name of the selected document

def build_knowledge_base(files):
    """Process the uploaded files and load the resulting knowledge base.

    Runs WorkFlowController over `files` (summarising/translating and writing a
    CSV of page embeddings plus a JSON summary), then loads the CSV into the
    module-global KNOWLEDGE_BASE DataFrame with the `page_embedding` column
    parsed into numpy arrays.

    Side effects: sets CSV_FILE_PATHS, JSON_FILE_PATHS and KNOWLEDGE_BASE.
    """
    global CSV_FILE_PATHS
    global JSON_FILE_PATHS
    global KNOWLEDGE_BASE

    work_flow_controller = WorkFlowController(files)
    CSV_FILE_PATHS = work_flow_controller.csv_result_path
    JSON_FILE_PATHS = work_flow_controller.result_path

    knowledge_base = pd.read_csv(CSV_FILE_PATHS, encoding='UTF-8')
    # ast.literal_eval replaces the original eval(): the column holds plain
    # list literals serialized by the pipeline, and literal_eval parses them
    # without executing arbitrary code if the CSV were ever tampered with.
    knowledge_base['page_embedding'] = (
        knowledge_base['page_embedding'].apply(ast.literal_eval).apply(np.array)
    )
    KNOWLEDGE_BASE = knowledge_base

def construct_summary():
    """Build a Markdown summary of every processed file.

    Reads the JSON knowledge base at JSON_FILE_PATHS and renders one section
    per file containing its name, page count and summarized content.

    Returns:
        str: concatenated Markdown sections (empty string if no files).
    """
    with open(JSON_FILE_PATHS, 'r', encoding='UTF-8') as fp:
        knowledge_base = json.load(fp)

    # Collect one section per file and join once, instead of += in a loop.
    sections = []
    for entry in knowledge_base.values():
        file_name = entry['file_name']
        total_page = entry['total_pages']
        summary = entry['summarized_content']
        sections.append(f"""
            ### 文件摘要
            {file_name}  (共 {total_page} 頁)<br><br>
            {summary}<br><br>
        """)
    return ''.join(sections)

def change_md():
    """Refresh the summary Markdown panel with the latest file summaries."""
    return gr.Markdown.update(construct_summary(), visible=True)

def user(message, history):
    """Append the user's message to the chat history with a pending bot slot.

    Returns a pair: an empty string (clears the input textbox) and a new
    history list ending in [message, None]; the original list is not mutated.
    """
    updated = list(history)
    updated.append([message, None])
    return "", updated

def system_notification(action):
    """Return the canned chat transcript for an upload-lifecycle event.

    'upload' yields the "processing" notice; any other action yields the
    "processing finished" notice.
    """
    if action == 'upload':
        reply = '文件處理中(摘要、翻譯等),結束後將自動回覆'
    else:
        reply = '文件處理完成,請開始提問'
    return [['已上傳文件', reply]]
    
def get_index_file(user_message):
    """Select the knowledge-base page most relevant to `user_message`.

    Embeds the question, ranks all pages by cosine distance, and — if the best
    distance is within the 0.2 threshold — stores that page's content, page
    number and file name in the CONTEXT* globals; otherwise sets CONTEXT to
    None so bot() can report that no relevant document was found.

    BUG FIX: the original reassigned the global KNOWLEDGE_BASE to
    `.head(1)`, permanently shrinking the knowledge base to a single row so
    every later question could only ever match the first retrieved page.
    The ranking now happens on a local copy.
    """
    global CONTEXT
    global CONTEXT_PAGE_NUM
    global CONTEXT_FILE_NAME

    user_message_embedding = openai.Embedding.create(
        input=user_message, engine='text-embedding-ada-002'
    )['data'][0]['embedding']
    distances = distances_from_embeddings(
        user_message_embedding,
        KNOWLEDGE_BASE['page_embedding'].values,
        distance_metric='cosine',
    )
    # Rank on a copy; KNOWLEDGE_BASE itself stays intact for future queries.
    ranked = KNOWLEDGE_BASE.assign(distance=distances).sort_values(
        by='distance', ascending=True
    )
    best = ranked.iloc[0]
    if best['distance'] > 0.2:
        CONTEXT = None
    else:
        CONTEXT = best['page_content']
        CONTEXT_PAGE_NUM = best['page_num']
        CONTEXT_FILE_NAME = best['file_name']

def bot(history):
    """Produce the bot reply for the newest user turn in `history`.

    `history` is the Gradio chat history: a list of [user, bot] pairs where
    the last entry's bot slot is still None (filled in by user()). Relies on
    the module globals KNOWLEDGE_BASE, CONTEXT, CONTEXT_PAGE_NUM and
    CONTEXT_FILE_NAME for retrieval state.
    """
    user_message = history[-1][0]
    global CONTEXT
    print(f'user_message: {user_message}')

    if KNOWLEDGE_BASE is None:
        # No documents uploaded yet.
        # NOTE(review): this replaces the entire history with a single
        # exchange instead of appending — confirm dropping earlier turns
        # is intended.
        response = [
            [user_message, "請先上傳文件"],
        ]
        history = response
        return history
    elif CONTEXT is None:
        # First question of a conversation: retrieve the best-matching page
        # (get_index_file sets CONTEXT as a side effect, or leaves it None
        # when nothing is within the distance threshold).
        get_index_file(user_message)
        print(f'CONTEXT: {CONTEXT}')
        if CONTEXT is None:
            response = [
                [user_message, "無法找到相關文件,請重新提問"],
            ]
            history = response
            return history
    else:
        # CONTEXT already pinned from an earlier turn; keep using it until
        # clear_state() resets it.
        pass

    if CONTEXT is not None:
        bot_message = qa_processor.answer_question(CONTEXT, CONTEXT_PAGE_NUM, CONTEXT_FILE_NAME, history)
        print(f'bot_message: {bot_message}')
        response = [
            [user_message, bot_message],
        ]
        # Patch the pending bot slot of the last turn in place.
        history[-1] = response[0]
        return history
    
def clear_state():
    """Reset the retrieval context so the next question re-selects a document."""
    global CONTEXT, CONTEXT_PAGE_NUM, CONTEXT_FILE_NAME
    CONTEXT = CONTEXT_PAGE_NUM = CONTEXT_FILE_NAME = None

with gr.Blocks() as demo:
    # State holders passed into the callbacks below.
    history = gr.State([])
    upload_state = gr.State("upload")
    finished = gr.State("finished")
    user_question = gr.State("")
    with gr.Row():
        gr.HTML('Junyi Academy Chatbot')
    with gr.Row(equal_height=True):
        with gr.Column(scale=5):
            with gr.Row():
                chatbot = gr.Chatbot()

            with gr.Row():
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text",
                        container=False,
                    )
                with gr.Column(min_width=70, scale=1):
                    clear_btn = gr.Button("清除")
                with gr.Column(min_width=70, scale=1):
                    submit_btn = gr.Button("傳送")

                # Enter key: record the user turn, then generate the bot reply,
                # then re-enable the textbox.
                response = user_input.submit(user,
                                  [user_input, chatbot],
                                  [user_input, chatbot],
                                  queue=False,
                                  ).then(bot, chatbot, chatbot)
                response.then(lambda: gr.update(interactive=True), None, [user_input], queue=False)

                # First clear handler: wipe the visible chat transcript.
                clear_btn.click(lambda: None, None, chatbot, queue=False)

                # Send button mirrors the Enter-key flow.
                # BUG FIX: the original passed `chatbot` as a stray 4th
                # positional argument to click() (after `outputs`), where it
                # does not belong in the event-listener signature.
                submit_btn.click(user,
                                [user_input, chatbot],
                                [user_input, chatbot],
                                queue=False).then(bot, chatbot, chatbot).then(lambda: gr.update(interactive=True), None, [user_input], queue=False)

                # Second clear handler: also reset the retrieval context so the
                # next question re-selects a document.
                clear_btn.click(clear_state, None, None, queue=False)

    with gr.Row():
        index_file = gr.File(file_count="multiple", file_types=["pdf"], label="Upload PDF file")

    with gr.Row():
        instruction = gr.Markdown("""
        ## 使用說明
        1. 上傳一個或多個 PDF 檔案,系統將自動進行摘要、翻譯等處理後建立知識庫
        2. 在上方輸入欄輸入問題,系統將自動回覆
        3. 可以根據下方的摘要內容來提問
        4. 每次對話會根據第一個問題的內容來檢索所有文件,並挑選最能回答問題的文件來回覆
        5. 要切換檢索的文件,請點選「清除對話記錄」按鈕後再重新提問
        """)

    with gr.Row():
        describe = gr.Markdown('', visible=True)

        # Upload pipeline: notify "processing" -> build the knowledge base ->
        # notify "done" -> render the per-file summaries.
        index_file.upload(system_notification, [upload_state], chatbot) \
                  .then(lambda: gr.update(interactive=True), None, None, queue=False) \
                  .then(build_knowledge_base, [index_file]) \
                  .then(system_notification, [finished], chatbot) \
                  .then(lambda: gr.update(interactive=True), None, None, queue=False) \
                  .then(change_md, None, describe)
                  
    
if __name__ == "__main__":
    # Launch the Gradio app only when run as a script (not on import).
    demo.launch()