Rag-chatbot / app.py
SC999's picture
Upload 5 files
216f163 verified
raw
history blame
5.31 kB
import gradio as gr
import os
from langchain.chains import ConversationalRetrievalChain
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from dotenv import load_dotenv
# 加載環境變量
load_dotenv()
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
# 驗證 OpenAI API Key
api_key = os.getenv('OPENAI_API_KEY')
if not api_key:
raise ValueError("請設置 'OPENAI_API_KEY' 環境變數")
# OpenAI API key
openai_api_key = api_key
# 將聊天歷史轉換為適合 LangChain 的二元組格式
def transform_history_for_langchain(history):
return [(chat[0], chat[1]) for chat in history if chat[0]] # 使用整數索引來訪問元組中的元素
# 將 Gradio 的歷史紀錄轉換為 OpenAI 格式
def transform_history_for_openai(history):
new_history = []
for chat in history:
if chat[0]:
new_history.append({"role": "user", "content": chat[0]})
if chat[1]:
new_history.append({"role": "assistant", "content": chat[1]})
return new_history
# 載入和處理文件的函數
def load_and_process_documents(folder_path):
documents = []
for file in os.listdir(folder_path):
file_path = os.path.join(folder_path, file)
if file.endswith(".pdf"):
loader = PyPDFLoader(file_path)
documents.extend(loader.load())
elif file.endswith('.docx') or file.endswith('.doc'):
loader = Docx2txtLoader(file_path)
documents.extend(loader.load())
elif file.endswith('.txt'):
loader = TextLoader(file_path)
documents.extend(loader.load())
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
documents = text_splitter.split_documents(documents)
vectordb = Chroma.from_documents(
documents,
embedding=OpenAIEmbeddings(),
persist_directory="./tmp"
)
return vectordb
# 初始化向量數據庫為全局變量
if 'vectordb' not in globals():
vectordb = load_and_process_documents("./")
# 定義查詢處理函數
def handle_query(user_message, temperature, chat_history):
try:
if not user_message:
return chat_history # 返回不變的聊天記錄
# 使用 LangChain 的 ConversationalRetrievalChain 處理查詢
preface = """
指令: 全部以繁體中文呈現,200字以內。
除了與文件相關內容可回答之外,與文件內容不相關的問題都必須回答:這問題很深奧,需要請示JohnLiao大神...
"""
query = f"{preface} 查詢內容:{user_message}"
# 提取之前的回答作為上下文,並轉換成 LangChain 支持的格式
previous_answers = transform_history_for_langchain(chat_history)
pdf_qa = ConversationalRetrievalChain.from_llm(
ChatOpenAI(temperature=temperature, model_name='gpt-4'),
retriever=vectordb.as_retriever(search_kwargs={'k': 6}),
return_source_documents=True,
verbose=False
)
# 調用模型進行查詢
result = pdf_qa.invoke({"question": query, "chat_history": previous_answers})
# 確保 'answer' 在結果中
if "answer" not in result:
return chat_history + [("系統", "抱歉,出現了一個錯誤。")]
# 更新對話歷史中的 AI 回應
chat_history[-1] = (user_message, result["answer"]) # 更新最後一個記錄,配對用戶輸入和 AI 回應
return chat_history
except Exception as e:
return chat_history + [("系統", f"出現錯誤: {str(e)}")]
# 使用 Gradio 的 Blocks API 創建自訂聊天介面
with gr.Blocks() as demo:
gr.Markdown("<h1 style='text-align: center;'>AI 小助教</h1>")
chatbot = gr.Chatbot()
state = gr.State([])
with gr.Row():
with gr.Column(scale=0.85):
txt = gr.Textbox(show_label=False, placeholder="請輸入您的問題...")
with gr.Column(scale=0.15, min_width=0):
submit_btn = gr.Button("提問")
# 用戶輸入後立即顯示提問文字,不添加回應部分,並清空輸入框
def user_input(user_message, history):
history.append((user_message, "")) # 顯示提問文字,回應部分為空字符串
return history, "", history # 返回清空的輸入框以及更新的聊天歷史
# 處理 AI 回應,更新回應部分
def bot_response(history):
user_message = history[-1][0] # 獲取最新的用戶輸入
history = handle_query(user_message, 0.7, history) # 調用處理函數
return history, history # 返回更新後的聊天記錄
# 先顯示提問文字,然後處理 AI 回應,並清空輸入框
submit_btn.click(user_input, [txt, state], [chatbot, txt, state], queue=False).then(
bot_response, state, [chatbot, state]
)
# 支援按 "Enter" 提交問題,立即顯示提問文字並清空輸入框
txt.submit(user_input, [txt, state], [chatbot, txt, state], queue=False).then(
bot_response, state, [chatbot, state]
)
# 啟動 Gradio 應用
demo.launch()