import os
import re
import sys
import logging
from threading import Thread
from typing import Any, List, Optional, Iterator

import torch
import hydra
from BCEmbedding.tools.langchain import BCERerank
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.retrievers import ContextualCompressionRetriever
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from modelscope.hub.snapshot_download import snapshot_download

# Make the project root importable so the local modules below resolve.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from rag.chroma_db import get_chroma_db
from download_models import download_model


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename='ragchat.log',
    filemode='w',
)


class InternLM(LLM):
    # Declared as class fields so that pydantic (which backs langchain's LLM
    # base class) accepts them; they are populated in __init__.
    tokenizer: AutoTokenizer = None
    model: AutoModelForCausalLM = None
    llm_system_prompt: str = ""

    def __init__(self, model_path: str, llm_system_prompt: str):
        super().__init__()
        logging.info(f"Loading model from local path: {model_path} ...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16).cuda()
            self.model.eval()
            self.llm_system_prompt = llm_system_prompt
            logging.info("Finished loading the local model.")
        except Exception as e:
            logging.error(f"Error while loading the model: {e}")
            raise

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any) -> str:
        # InternLM's remote-code `chat` API; the system prompt is passed as the
        # first entry of the conversation history.
        messages = [("system", self.llm_system_prompt)]
        response, history = self.model.chat(self.tokenizer, prompt, history=messages)
        return response

    def stream(self, prompt: str) -> Iterator[str]:
        """Yield the answer incrementally, running generation in a background thread."""
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").cuda()
        # skip_prompt=True keeps the echoed input prompt out of the stream.
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = dict(
            input_ids=input_ids,
            max_new_tokens=2048,
            streamer=streamer,
        )

        # `generate` blocks until completion, so run it in a thread and consume
        # the streamer on this one.
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        for new_text in streamer:
            yield new_text

    @property
    def _llm_type(self) -> str:
        return "InternLM2"


class WuleRAG:
    """
    Holds the retrieval question-answering chain and its components.
    """
    def __init__(self, data_source_dir, db_persist_directory, base_model, embeddings_model, reranker_model, rag_prompt_template):
        self.llm = base_model

        # Download the BCE embedding model from ModelScope if it is missing locally.
        if not os.path.exists(embeddings_model):
            if embeddings_model.endswith("bce-embedding-base_v1"):
                save_dir = os.path.dirname(os.path.dirname(embeddings_model))
                embeddings_model = snapshot_download("maidalun/bce-embedding-base_v1", cache_dir=save_dir, revision='master')
                logging.info(f"bce-embedding model did not exist; downloaded from ModelScope and saved to {embeddings_model}")
            else:
                raise ValueError(f"{embeddings_model} does not exist; please reset the path or re-download the model.")

        self.embeddings = HuggingFaceEmbeddings(
            model_name=embeddings_model,
            model_kwargs={"device": "cuda"},
            encode_kwargs={"batch_size": 1, "normalize_embeddings": True},
        )

        # Download the BCE reranker model from ModelScope if it is missing locally.
        if not os.path.exists(reranker_model):
            if reranker_model.endswith("bce-reranker-base_v1"):
                save_dir = os.path.dirname(os.path.dirname(reranker_model))
                reranker_model = snapshot_download("maidalun/bce-reranker-base_v1", cache_dir=save_dir, revision='master')
                logging.info(f"bce-reranker model did not exist; downloaded from ModelScope and saved to {reranker_model}")
            else:
                raise ValueError(f"{reranker_model} does not exist; please reset the path or re-download the model.")

        reranker_args = {'model': reranker_model, 'top_n': 5, 'device': 'cuda', "use_fp16": True}
        self.reranker = BCERerank(**reranker_args)

        # Build (or load) the Chroma vector store over the source documents.
        vectordb = get_chroma_db(data_source_dir, db_persist_directory, self.embeddings)

        # First stage: similarity retrieval with a score threshold.
        self.retriever = vectordb.as_retriever(
            search_kwargs={"k": 3, "score_threshold": 0.6},
            search_type="similarity_score_threshold",
        )

        # Second stage: rerank the retrieved candidates with the BCE reranker.
        self.compression_retriever = ContextualCompressionRetriever(
            base_compressor=self.reranker, base_retriever=self.retriever
        )
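
        # Note: the base retriever returns at most k=3 documents while the
        # reranker keeps top_n=5, so reranking reorders candidates rather than
        # truncating them; raise k if the reranker should also prune.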

        # Prompt template that injects the retrieved context and the user question.
        self.PROMPT = PromptTemplate(
            template=rag_prompt_template, input_variables=["context", "question"]
        )

        # "stuff" chain: all reranked documents are packed into a single prompt.
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.compression_retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.PROMPT}
        )
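
        # The assembled chain can be exercised directly, e.g.:
        #   result = self.qa_chain.invoke({"query": "..."})
        #   result["result"]            # the answer text
        #   result["source_documents"]  # retrieved docs (return_source_documents=True)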

    def query_stream(self, query: str) -> Iterator[str]:
        """Retrieve and rerank documents for the query, then stream the LLM's answer."""
        docs = self.compression_retriever.get_relevant_documents(query)
        context = "\n\n".join(doc.page_content for doc in docs)
        prompt = self.PROMPT.format(context=context, question=query)
        return self.llm.stream(prompt)

    def query(self, question: str) -> str:
        """
        Answer a question through the retrieval QA chain; on failure, return a
        fallback message instead of raising.

        Example:
            question = '黑神话悟空发售时间和团队?主要讲了什么故事?'
            result = self.qa_chain.invoke({"query": question})
            print(result["result"])
        """
        if not question:
            return "请提供一个有用的问题。"  # "Please provide a useful question."

        try:
            result = self.qa_chain.invoke({"query": question})
            if 'result' in result:
                answer = result['result']
                # Strip the boilerplate prefix "根据提供的信息," ("According to the
                # provided information,") that the model tends to emit.
                final_answer = re.sub(r'^根据提供的信息,\s?', '', answer, flags=re.M).strip()
                return final_answer
            else:
                logging.error("Error: 'result' field not found in the chain output.")
                return "悟了悟了目前无法提供答案,请稍后再试。"  # "Wulewule cannot answer right now; please try again later."
        except Exception as e:
            import traceback
            logging.error(f"An error occurred: {e}\n{traceback.format_exc()}")
            return "悟了悟了遇到了一些技术问题,正在修复中。"  # "Wulewule hit a technical issue; a fix is underway."


@hydra.main(version_base=None, config_path="../configs", config_name="rag_cfg")
def main(config):
    data_source_dir = config.data_source_dir
    db_persist_directory = config.db_persist_directory
    llm_model = config.llm_model
    embeddings_model = config.embeddings_model
    reranker_model = config.reranker_model
    llm_system_prompt = config.llm_system_prompt
    rag_prompt_template = config.rag_prompt_template

    # Download the LLM weights if they are not present locally.
    if not os.path.exists(llm_model):
        download_model(llm_model_path=llm_model)

    base_model = InternLM(model_path=llm_model, llm_system_prompt=llm_system_prompt)

    wulewule_rag = WuleRAG(data_source_dir, db_persist_directory, base_model, embeddings_model, reranker_model, rag_prompt_template)

    # Sample question about Black Myth: Wukong (release date, team, and story).
    question = """黑神话悟空发售时间和团队?主要讲了什么故事?"""

    if config.stream_response:
        logging.info("Streaming response:")
        for chunk in wulewule_rag.query_stream(question):
            print(chunk, end='', flush=True)
        print("\n")
    else:
        response = wulewule_rag.query(question)
        logging.info(f"question: {question}\n wulewule answer:\n{response}")


if __name__ == "__main__":
    main()
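
# Configuration comes from ../configs/rag_cfg.yaml via Hydra; any field can be
# overridden on the command line, e.g. (the script filename is assumed):
#   python rag_chat.py stream_response=true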