Spaces:
Sleeping
Sleeping
""" | |
1. rag_reponse_002.py is a modified version of rag_reponse_001.py. 主要是为了测试用ChatGPT+Reranker+最后给出相似查询的页面结构。 | |
""" | |
##TODO: 1. 将LLM改成ChatGPT. 2. Reranker. 3. 最后给出相似查询的页面结构 | |
import sentence_transformers | |
from langchain_community.vectorstores import FAISS | |
from langchain.embeddings.huggingface import HuggingFaceEmbeddings | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain_core.runnables import RunnableParallel | |
import streamlit as st | |
import re | |
import openai | |
import os | |
from langchain.llms.base import LLM | |
from langchain.llms.utils import enforce_stop_tokens | |
from typing import Dict, List, Optional, Tuple, Union | |
# import chatgpt | |
import qwen_response | |
from dotenv import load_dotenv | |
import dashscope | |
load_dotenv() | |
### 设置openai的API key | |
os.environ["OPENAI_API_KEY"] = os.environ['user_token'] | |
openai.api_key = os.environ['user_token'] | |
bing_search_api_key = os.environ['bing_api_key'] | |
dashscope.api_key = os.environ['dashscope_api_key'] | |
from langchain.embeddings.openai import OpenAIEmbeddings | |
# embeddings = HuggingFaceEmbeddings(model_name='GanymedeNil/text2vec-large-chinese') ## 这里是联网情况下,部署在Huggingface上后使用。 | |
# embeddings = OpenAIEmbeddings(disallowed_special=()) ## 这里是联网情况下,部署在Huggingface上后使用。 | |
# embeddings = HuggingFaceEmbeddings(model_name='/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/Coding/RAG/bge-large-zh') ## 切换成BGE的embedding。 | |
# vector_store = FAISS.load_local("./faiss_index/", embeddings=embeddings, allow_dangerous_deserialization=True) ## 加载vector store到本地。 | |
# vector_store = FAISS.load_local("./faiss_index/", embeddings=embeddings) ## 加载vector store到本地。 ### original code here. | |
# ## 配置ChatGLM的类与后端api server对应。 | |
# class ChatGLM(LLM): | |
# max_token: int = 8096 ### 无法输出response的时候,可以看一下这里。 | |
# temperature: float = 0.7 | |
# top_p = 0.9 | |
# history = [] | |
# def __init__(self): | |
# super().__init__() | |
# @property | |
# def _llm_type(self) -> str: | |
# return "ChatGLM" | |
# def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: | |
# # headers中添加上content-type这个参数,指定为json格式 | |
# headers = {'Content-Type': 'application/json'} | |
# data=json.dumps({ | |
# 'prompt':prompt, | |
# 'temperature':self.temperature, | |
# 'history':self.history, | |
# 'max_length':self.max_token | |
# }) | |
# print("ChatGLM prompt:",prompt) | |
# # 调用api | |
# # response = requests.post("http://0.0.0.0:8000",headers=headers,data=data) ##working。 | |
# response = requests.post("http://127.0.0.1:8000",headers=headers,data=data) ##working。 | |
# print("ChatGLM resp:", response) | |
# if response.status_code!=200: | |
# return "查询结果错误" | |
# resp = response.json() | |
# if stop is not None: | |
# response = enforce_stop_tokens(response, stop) | |
# self.history = self.history+[[None, resp['response']]] ##original | |
# return resp['response'] ##original. | |
## 在绝对路径中提取完整的文件名 | |
def extract_document_name(path): | |
# 路径分割 | |
path_segments = path.split("/") | |
# 文件名提取 | |
document_name = path_segments[-1] | |
return document_name | |
## 从一段话中提取 1 句完整的句子,且该句子的长度必须超过 5 个词,同时去除了换行符'\n\n'。 | |
import re | |
def extract_sentence(text): | |
""" | |
从一段话中提取 1 句完整的句子,且该句子的长度必须超过 5 个词。 | |
Args: | |
text: 一段话。 | |
Returns: | |
提取到的句子。 | |
""" | |
# 去除换行符。 | |
text = text.replace('\n\n', '') | |
# 使用正则表达式匹配句子。 | |
sentences = re.split(r'[。?!;]', text) | |
# 过滤掉长度小于 5 个词的句子。 | |
sentences = [sentence for sentence in sentences if len(sentence.split()) >= 5] | |
# 返回第一句句子。 | |
return sentences[0] if sentences else None | |
### 综合source的输出内容。 | |
def rag_source(docs): | |
print('starting source function!') | |
source = "" | |
for i, doc in enumerate(docs): | |
# Get the sentence or use a fallback if None is returned | |
extracted_sentence = extract_sentence(doc.page_content) or "内容较短,无法提取完整句子" | |
source += (f"**【信息来源 {i+1}】** " + | |
extract_document_name(doc.metadata['source']) + | |
',' + | |
f"第{docs[i].metadata['page']+1}页" + | |
',部分内容摘录:' + | |
extracted_sentence + | |
'\n\n') | |
print('source:', source) | |
return source | |
def rag_response(username, user_input, k=3): | |
# docs = vector_store.similarity_search('user_input', k=k) ## Original。 | |
embeddings = HuggingFaceEmbeddings(model_name='BAAI/bge-large-zh-v1.5') ## 这里是联网情况下,部署在Huggingface上后使用。 | |
# embeddings = HuggingFaceEmbeddings(model_name='GanymedeNil/text2vec-large-chinese') ## 这里是联网情况下,部署在Huggingface上后使用。 | |
print('embeddings:', embeddings) | |
vector_store = FAISS.load_local(f"./{username}/faiss_index/", embeddings=embeddings, allow_dangerous_deserialization=True) ## 加载vector store到本地。 | |
docs = vector_store.similarity_search(user_input, k=k) ##TODO 'user_input' to user_input? | |
context = [doc.page_content for doc in docs] | |
# print('context: {}'.format(context)) | |
source = rag_source(docs=docs) ## 封装到一个函数中。 | |
## 用大模型来回答问题。 | |
# llm = ChatGLM() ## 启动一个实例。 | |
# final_prompt = f"已知信息:\n{context}\n 根据这些已知信息来回答问题:\n{user_input}" | |
final_prompt = f"已知信息:\n{context}\n 根据这些已知信息尽可能详细且专业地来回答问题:\n{user_input}" | |
## LLM的回答 | |
# response = llm(prompt=final_prompt) ## 通过实例化之后的LLM来输出结果。 | |
# response = chatgpt.chatgpt(user_prompt=final_prompt) ## 通过ChatGPT实例化之后的LLM来输出结果。 | |
response = qwen_response.call_with_messages(prompt=final_prompt)# import | |
# response = llm(prompt=final_prompt) ## 通过实例化之后的LLM来输出结果。 | |
# response = llm(prompt='where is shanghai') | |
# print('response now:' + response) | |
return response, source | |
# # import asyncio | |
# response, source = rag_response('我是一个企业主,我需要关注哪些存货的数据资源规则?') | |
# print(response) | |
# print(source) |