File size: 5,008 Bytes
b2e325f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from langchain_community.vectorstores import FAISS
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables import RunnableParallel
import streamlit as st
import re
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from typing import Dict, List, Optional, Tuple, Union
import requests
import json

embeddings = HuggingFaceEmbeddings(model_name='/Users/yunshi/Downloads/同步空间/LLM/2023ChatGPT/Coding/RAG/bge-large-zh/') ## 切换成BGE的embedding。
# embeddings = HuggingFaceEmbeddings(model_name='/Users/yunshi/Downloads/360Data/Data Center/Working-On Task/演讲与培训/2023ChatGPT/RAG/bge-large-zh/') ## 切换成BGE的embedding。
vector_store = FAISS.load_local("./faiss_index/", embeddings=embeddings, allow_dangerous_deserialization=True) ## 加载vector store到本地。
# vector_store = FAISS.load_local("./faiss_index/", embeddings=embeddings) ## 加载vector store到本地。 ### original code here.

## 配置ChatGLM的类与后端api server对应。
class ChatGLM(LLM):
    max_token: int = 8096 ###  无法输出response的时候,可以看一下这里。
    temperature: float = 0.7
    top_p = 0.9
    history = []

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "ChatGLM"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        # headers中添加上content-type这个参数,指定为json格式
        headers = {'Content-Type': 'application/json'}
        data=json.dumps({
            'prompt':prompt,
            'temperature':self.temperature,
            'history':self.history,
            'max_length':self.max_token
        })
        print("ChatGLM prompt:",prompt)
        # 调用api
        # response = requests.post("http://0.0.0.0:8000",headers=headers,data=data) ##working。
        response = requests.post("http://127.0.0.1:8000",headers=headers,data=data) ##working。
        print("ChatGLM resp:", response)
        
        if response.status_code!=200:
            return "查询结果错误"
        resp = response.json()
        if stop is not None:
            response = enforce_stop_tokens(response, stop)
        self.history = self.history+[[None, resp['response']]] ##original
        return resp['response'] ##original.

## 在绝对路径中提取完整的文件名
def extract_document_name(path):
    # 路径分割
    path_segments = path.split("/")
    # 文件名提取
    document_name = path_segments[-1]
    return document_name

## 从一段话中提取 1 句完整的句子,且该句子的长度必须超过 5 个词,同时去除了换行符'\n\n'。
import re
def extract_sentence(text):
    """
    从一段话中提取 1 句完整的句子,且该句子的长度必须超过 5 个词。

    Args:
        text: 一段话。

    Returns:
        提取到的句子。
    """

    # 去除换行符。
    text = text.replace('\n\n', '')
    # 使用正则表达式匹配句子。
    sentences = re.split(r'[。?!;]', text)

    # 过滤掉长度小于 5 个词的句子。
    sentences = [sentence for sentence in sentences if len(sentence.split()) >= 5]

    # 返回第一句句子。
    return sentences[0] if sentences else None

### 综合source的输出内容。
def rag_source(docs):
    print('starting source function!')
    source = ""
    for i, doc in enumerate(docs):
        source += f"**【信息来源 {i+1}】** " + extract_document_name(doc.metadata['source']) + ',' + f"第{docs[i].metadata['page']+1}页" + ',部分内容摘录:' + extract_sentence(doc.page_content) + '\n\n'
    print('source:', source)
    return source

def rag_response(user_input, k=3):
    # docs = vector_store.similarity_search('user_input', k=k) ## Original。
    docs = vector_store.similarity_search(user_input, k=k) ##TODO 'user_input' to user_input?
    context = [doc.page_content for doc in docs]
    # print('context: {}'.format(context))

    source = rag_source(docs=docs) ## 封装到一个函数中。
    
    ## 用大模型来回答问题。
    llm = ChatGLM() ## 启动一个实例。
    # final_prompt = f"已知信息:\n{context}\n 根据这些已知信息来回答问题:\n{user_input}"
    final_prompt = f"已知信息:\n{context}\n 根据这些已知信息尽可能详细且专业地来回答问题:\n{user_input}"
    
    response = llm(prompt=final_prompt) ## 通过实例化之后的LLM来输出结果。
    # response = llm(prompt=final_prompt) ## 通过实例化之后的LLM来输出结果。
    # response = llm(prompt='where is shanghai')
    # print('response now:' + response)
    
    return response, source

# # import asyncio
# response, source = rag_response('我是一个企业主,我需要关注哪些存货的数据资源规则?')
# print(response)
# print(source)