# pipelines/examples/production_rag.py
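"""Production RAG filter pipeline.

Guards unsafe input, answers repeated questions from a semantic cache,
classifies intent to select a knowledge base, then retrieves context with a
hybrid BM25 + multi-query dense ensemble, reranks it with Cohere, and injects
the result as a system message.
"""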
import os
import time
import ast
from typing import List, Optional
from pydantic import BaseModel
# from semantic_router.route import Route
from Router.router import Evaluator
from semantic_router.samples import rag_sample, chitchatSample
from utils.pipelines.main import get_last_user_message, add_or_update_system_message, pop_system_message
from blueprints.rag_utils import format_docs
from blueprints.prompts import accurate_rag_prompt, QUERY_PROMPT, evaluator_intent, basic_template, chitchat_prompt
from BM25 import BM25SRetriever
from SafetyChecker import SafetyChecker
from langchain.retrievers import EnsembleRetriever
from semantic_cache.main import SemanticCache
from sentence_transformers import SentenceTransformer
# from database_Routing import DB_Router
from langchain.retrievers.multi_query import MultiQueryRetriever
import cohere
from langchain_core.output_parsers import BaseOutputParser
from langchain_cohere import CohereRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_groq import ChatGroq
from langchain_core.runnables import RunnablePassthrough
# import logging
# logging.basicConfig(
#     level=logging.INFO,
#     format='%(levelname)s - %(message)s')
from dotenv import load_dotenv
load_dotenv()
qdrant_url = os.getenv('URL_QDRANT')
qdrant_api = os.getenv('API_QDRANT')
# Fail fast if the Cohere key required by the reranker is missing.
if "COHERE_API_KEY" not in os.environ:
    raise EnvironmentError("COHERE_API_KEY is not set")
##### Output parser #####
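# Splits the multi-query LLM's newline-separated output into a list of
# individual queries for MultiQueryRetriever.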
class LineListOutputParser(BaseOutputParser[List[str]]):
"""Output parser for a list of lines."""
def parse(self, text: str) -> List[str]:
lines = text.strip().split("\n")
return list(filter(None, lines)) # Remove empty lines
class Pipeline:
class Valves(BaseModel):
# List target pipeline ids (models) that this filter will be connected to.
# If you want to connect this filter to all pipelines, you can set pipelines to ["*"]
pipelines: List[str] = []
# Assign a priority level to the filter pipeline.
# The priority level determines the order in which the filter pipelines are executed.
# The lower the number, the higher the priority.
priority: int = 0
# Add your custom parameters/configuration here e.g. API_KEY that you want user to configure etc.
def __init__(self):
self.type = "filter"
self.name = "Filter"
self.embedding = None
self.route = None
self.stsv_db = None
self.gthv_db = None
self.ttts_db = None
self.reranker = None
self.valves = self.Valves(**{"pipelines": ["*"]})
    def split_context(self, context):
        """Rebuild the system message from a rendered RAG prompt: keep the
        instructions and the <context> block, drop the bare user question."""
        split_index = context.find("User question")
        if split_index == -1:
            # Marker not found; fall back to the full rendered prompt.
            return context.strip()
        system_prompt = context[:split_index].strip()
        user_question = context[split_index:].strip()
        user_split_index = user_question.find("<context>")
        return f"{system_prompt}\n{user_question[user_split_index:]}"
async def on_startup(self):
# This function is called when the server is started.
print(f"on_startup:{__name__}")
        from langchain_community.vectorstores import Qdrant
        from langchain_huggingface import HuggingFaceEmbeddings
        from qdrant_client import QdrantClient

        # Vietnamese sentence embeddings, used directly and through LangChain.
        self.embedding = SentenceTransformer("dangvantuan/vietnamese-embedding")
        HF_EMBEDDING = HuggingFaceEmbeddings(model_name="dangvantuan/vietnamese-embedding")
# client = QdrantClient(
# qdrant_url,
# api_key=qdrant_api
# )
client = QdrantClient(url="http://localhost:6333")
        gthv = Qdrant(client, collection_name="gioithieuhocvien_db", embeddings=HF_EMBEDDING)
        self.gthv_db = gthv.as_retriever()
        stsv = Qdrant(client, collection_name="sotaysinhvien_db", embeddings=HF_EMBEDDING)
        self.stsv_db = stsv.as_retriever()
        ttts = Qdrant(client, collection_name="thongtintuyensinh_db", embeddings=HF_EMBEDDING)
        self.ttts_db = ttts.as_retriever()
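        # Load the pre-chunked corpora that back the lexical BM25S retrievers.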
import pickle
with open('data/thongtintuyensinh.pkl', 'rb') as f:
self.thongtintuyensinh = pickle.load(f)
with open('data/sotaysinhvien.pkl', 'rb') as f:
self.sotaysinhvien = pickle.load(f)
with open('data/gioithieuhocvien.pkl', 'rb') as f:
self.gioithieuhocvien = pickle.load(f)
        self.retriever_bm25_tuyensinh = BM25SRetriever.from_documents(self.thongtintuyensinh, k=5, activate_numba=True)
        self.retriever_bm25_sotay = BM25SRetriever.from_documents(self.sotaysinhvien, k=5, activate_numba=True)
        self.retriever_bm25_hocvien = BM25SRetriever.from_documents(self.gioithieuhocvien, k=5, activate_numba=True)
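        # Semantic cache short-circuits repeated questions; the Cohere
        # multilingual reranker orders the fused retrieval candidates.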
self.cache = SemanticCache()
        self.reranker = CohereRerank(model="rerank-multilingual-v3.0", top_n=5)
        llm = ChatGroq(model_name="llama3-70b-8192", temperature=0.1, api_key=os.getenv('llm_api_5'))
        output_parser = LineListOutputParser()
        # Chain that expands one user question into several query variants.
        self.llm_chain = QUERY_PROMPT | llm | output_parser
async def on_shutdown(self):
# This function is called when the server is stopped.
print(f"on_shutdown:{__name__}")
def get_last_assistant_message(self, messages: List[dict]) -> str:
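        """Return the text of the most recent assistant message,
        unpacking list-style multi-part content if present."""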
for message in reversed(messages):
if message["role"] == "assistant":
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"] == "text":
return item["text"]
return message["content"]
return ""
    def add_or_update_system_message(self, content: str, messages: List[dict]):
        """
        Append ``content`` to an existing leading system message, or insert a
        new system message at the beginning of the list if none exists.
        :param content: The text to add to the system message.
        :param messages: The list of message dictionaries.
        :return: The updated list of message dictionaries.
        """
        if messages and messages[0].get("role") == "system":
            messages[0]["content"] += f"{content}\n"
        else:
            # No leading system message; insert one at the beginning.
            messages.insert(0, {"role": "system", "content": content})
        return messages
def add_messages(self,content: str, messages: List[dict]):
messages.insert(0, {"role": "system", "content": content})
return messages
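    # Class-level default; updated per request in inlet().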
cache_hit = False
async def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
messages = body.get("messages", [])
print(messages)
user_message = get_last_user_message(messages)
print(user_message)
        ##### Guard #####
        # Safety check: short-circuit with the checker's response for unsafe input.
        checker = SafetyChecker()
        safety_result = checker.check_safety(user_message)
        if safety_result != 'safe':
            print("Safety check:", safety_result)
            # Prompt text: "Based on the information, answer the user's question in Vietnamese."
            construct_msg = f"Dựa vào thông tin trả lời câu hỏi của người dùng bằng Tiếng Việt\n\n : {safety_result}"
            body["messages"] = self.add_messages(construct_msg, messages)
            print(body)
            return body
#####Router#####
# MTA_ROUTE_NAME = 'mta'
# CHITCHAT_ROUTE_NAME = 'chitchat'
# mtaRoute = Route(name=MTA_ROUTE_NAME, samples=rag_sample)
# chitchatRoute = Route(name=CHITCHAT_ROUTE_NAME, samples=chitchatSample)
# router = SemanticRouter(self.embedding, routes=[mtaRoute, chitchatRoute])
# guidedRoute = router.guide(user_message)[1]
# print("Semantic Router :", guidedRoute)
        cache_result = self.cache.checker(user_message)
        if cache_result is not None:
            print("### Cache hit! ###")
            self.cache_hit = True
            construct_msg = f"Dựa vào thông tin trả lời câu hỏi của người dùng bằng Tiếng Việt \n\n : {cache_result}"
            body["messages"] = self.add_or_update_system_message(construct_msg, messages)
            print(body)
            return body
        self.cache_hit = False
        print("No cache hit! Continuing with generation")
evaluator = Evaluator(llm="llama3-8b", prompt=evaluator_intent)
output = evaluator.classify_text(user_message)
retriever = None
        print(f'User question: {user_message}')
# print(output.result)
        if output and output.result == 'OUT_OF_SCOPE':
            print('OUT OF SCOPE')
            construct_msg = f"Dựa vào thông tin trả lời câu hỏi của người dùng bằng Tiếng Việt\n\n : {chitchat_prompt}"
            body["messages"] = self.add_or_update_system_message(construct_msg, messages)
            print(body)
            return body
        elif output and output.result == 'ASK_QUYDINH':
            print('SO TAY SINH VIEN DB')
            retriever = self.stsv_db
            retriever_bm25 = self.retriever_bm25_sotay
            # db = self.sotaysinhvien
        elif output and output.result == 'ASK_HOCVIEN':
            print('HOC VIEN DB')
            retriever = self.gthv_db
            retriever_bm25 = self.retriever_bm25_hocvien
            # db = self.gioithieuhocvien
        elif output and output.result == 'ASK_TUYENSINH':
            print('THONG TIN TUYEN SINH DB')
            retriever = self.ttts_db
            retriever_bm25 = self.retriever_bm25_tuyensinh
            # db = self.thongtintuyensinh
        if retriever is not None:
            # Dense side: expand the question into multiple queries before retrieval.
            retriever_multi = MultiQueryRetriever(
                retriever=retriever, llm_chain=self.llm_chain, parser_key="lines"
            )
            # retriever_bm25 = BM25SRetriever.from_documents(db, k=5, activate_numba=True)
            # Hybrid retrieval: equal-weight fusion of lexical BM25 and dense multi-query results.
            ensemble_retriever = EnsembleRetriever(
                retrievers=[retriever_bm25, retriever_multi], weights=[0.5, 0.5]
            )
            # Rerank the fused candidates with Cohere before building the prompt.
            compression = ContextualCompressionRetriever(
                base_compressor=self.reranker, base_retriever=ensemble_retriever
            )
            rag_chain = (
                {"context": compression | format_docs, "question": RunnablePassthrough()}
                | basic_template
            )
            # last_assistant = self.get_last_assistant_message(messages)
            # rag_prompt = rag_chain.invoke(user_message + "\n" + last_assistant).text
            rag_prompt = rag_chain.invoke(user_message).text
            system_message = self.split_context(rag_prompt)
            body["messages"] = self.add_or_update_system_message(
                system_message, messages
            )
            print(body)
            # self.cache.add_to_cache(question, response_text)
            return body
        else:
            print('Retriever is not defined. Check the classifier output and ensure a retriever is assigned.')
            # Return the body unchanged so the pipeline contract is honoured.
            return body
    async def outlet(self, body: dict, user: Optional[dict] = None) -> dict:
print("##########################")
messages = body.get("messages", [])
# print(messages)
user_message = get_last_user_message(messages)
print(user_message)
print("########### Câu hỏi vừa hỏi #################")
# output_list = ast.literal_eval(user_message)
# print(output_list)
# print(output_list[-2]['content'])
# print(output_list[-1]['content'])
# print(f"outlet:{__name__}")
# print(f'##### Cache hit = {self.cache_hit}')
# if body and self.cache_hit == False:
# print(body['messages'][-2]['content'])
# print(body['messages'][-1]['content'])
# self.cache.add_to_cache(body['messages'][-2]['content'], body['messages'][-1]['content'])
print(f"Outlet Body Input: {body}")
return body
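

if __name__ == "__main__":
    # Minimal local smoke test -- a sketch, not part of the served pipeline.
    # It assumes a Qdrant instance on localhost:6333, the data/*.pkl corpora,
    # and the COHERE_API_KEY / llm_api_5 environment variables are available.
    import asyncio

    async def _demo():
        pipeline = Pipeline()
        await pipeline.on_startup()
        # Hypothetical question ("What are this year's admission cut-off scores?").
        body = {"messages": [{"role": "user", "content": "Điểm chuẩn tuyển sinh năm nay là bao nhiêu?"}]}
        body = await pipeline.inlet(body)
        print(body)
        await pipeline.on_shutdown()

    asyncio.run(_demo())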