# pipelines/examples/production_ready_rag.py
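# Overview (summary of the flow implemented below): inlet() routes each user message
# with a SemanticRouter (RAG vs. chitchat), checks a semantic cache, picks one of three
# Qdrant collections via DB_Router, reranks the retrieved documents with Cohere, and
# injects the resulting context as a system message; outlet() stores newly generated
# answers in the cache after a miss.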
import os
import time
from typing import List, Optional
from pydantic import BaseModel
from schemas import OpenAIChatMessage
from semantic_router.route import Route
from semantic_router.samples import rag_sample, chitchatSample
from utils.pipelines.main import get_last_user_message, add_or_update_system_message, get_last_assistant_message
from blueprints.rag_utils import format_docs
from blueprints.prompts import accurate_rag_prompt
from BM25 import BM25SRetriever
from semantic_router import SemanticRouter
from SafetyChecker import SafetyChecker
from semantic_cache import Cache
from sentence_transformers import SentenceTransformer
from database_Routing import DB_Router
import cohere
from langchain_cohere import CohereRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_core.runnables import RunnablePassthrough
from dotenv import load_dotenv
load_dotenv()
qdrant_url = os.getenv('URL_QDRANT')
qdrant_api = os.getenv('API_QDRANT')
os.environ["COHERE_API_KEY"]
#####Embedding model######
class Pipeline:
    class Valves(BaseModel):
        # List target pipeline ids (models) that this filter will be connected to.
        # If you want to connect this filter to all pipelines, you can set pipelines to ["*"].
        pipelines: List[str] = []

        # Assign a priority level to the filter pipeline.
        # The priority level determines the order in which the filter pipelines are executed.
        # The lower the number, the higher the priority.
        priority: int = 0

        # Add your custom parameters/configuration here, e.g. an API_KEY that you want the user to configure.
        pass

    def __init__(self):
        self.type = "filter"
        self.name = "Filter"
        self.embedding = None
        self.route = None

        # Semantic-cache state shared between inlet() and outlet().
        self.cache_flag = None
        self.cache_embedding = None
        self.cache_answer = None
        self.cache = Cache(embedding='keepitreal/vietnamese-sbert')

        # Qdrant retrievers and the reranker are initialised in on_startup().
        self.stsv_db = None
        self.gthv_db = None
        self.ttts_db = None
        self.reranker = None

        self.valves = self.Valves(**{"pipelines": ["*"]})
    def split_context(self, context):
        # The rendered RAG prompt contains the system instructions followed by a
        # "User question" section that embeds the retrieved <context> block.
        # Keep the system instructions and re-attach only the <context> part so the
        # result can be used as a standalone system message.
        split_index = context.find("User question")
        system_prompt = context[:split_index].strip()
        user_question = context[split_index:].strip()
        user_split_index = user_question.find("<context>")
        f_system_prompt = str(system_prompt) + str(user_question[user_split_index:])
        return f_system_prompt
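    # Assumed layout of the prompt rendered by accurate_rag_prompt (illustrative only):
    #
    #   <system instructions ...>
    #   User question: <the user's question>
    #   <context>
    #   ... reranked documents ...
    #   </context>
    #
    # split_context() drops the "User question" line and returns the instructions
    # followed by the <context> block.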
    async def on_startup(self):
        # This function is called when the server is started.
        print(f"on_startup:{__name__}")

        from typing import List
        import bs4
        from langchain.text_splitter import RecursiveCharacterTextSplitter
        from langchain_community.document_loaders import WebBaseLoader
        from langchain_community.vectorstores import Qdrant
        from langchain_core.runnables import RunnablePassthrough
        from langchain_qdrant import QdrantVectorStore
        from langchain_huggingface import HuggingFaceEmbeddings
        from qdrant_client import QdrantClient

        # Sentence-transformers model used by the semantic router; the same model backs
        # the HuggingFaceEmbeddings wrapper used for the Qdrant collections.
        self.embedding = SentenceTransformer('keepitreal/vietnamese-sbert')
        HF_EMBEDDING = HuggingFaceEmbeddings(model_name="keepitreal/vietnamese-sbert")

        client = QdrantClient(
            qdrant_url,
            api_key=qdrant_api,
        )

        # One MMR retriever per Qdrant collection:
        # gioithieuhocvien_db (academy introduction), sotaysinhvien_db (student handbook),
        # thongtintuyensinh_db (admission information).
        gthv = Qdrant(client, collection_name="gioithieuhocvien_db", embeddings=HF_EMBEDDING)
        self.gthv_db = gthv.as_retriever(search_type="mmr", search_kwargs={"k": 20})
        stsv = Qdrant(client, collection_name="sotaysinhvien_db", embeddings=HF_EMBEDDING)
        self.stsv_db = stsv.as_retriever(search_type="mmr", search_kwargs={"k": 20})
        ttts = Qdrant(client, collection_name="thongtintuyensinh_db", embeddings=HF_EMBEDDING)
        self.ttts_db = ttts.as_retriever(search_type="mmr", search_kwargs={"k": 20})

        # Cohere reranker used for contextual compression of the retrieved documents.
        self.reranker = CohereRerank(model="rerank-multilingual-v3.0", top_n=10)
    async def on_shutdown(self):
        # This function is called when the server is stopped.
        print(f"on_shutdown:{__name__}")
        pass
    async def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
        messages = body.get("messages", [])
        user_message = get_last_user_message(messages)
        print(user_message)

        ##### Semantic router: decide between RAG ('mta') and plain chitchat #####
        MTA_ROUTE_NAME = 'mta'
        CHITCHAT_ROUTE_NAME = 'chitchat'
        mtaRoute = Route(name=MTA_ROUTE_NAME, samples=rag_sample)
        chitchatRoute = Route(name=CHITCHAT_ROUTE_NAME, samples=chitchatSample)
        router = SemanticRouter(self.embedding, routes=[mtaRoute, chitchatRoute])

        ##### Semantic cache lookup #####
        cache_result = self.cache.cached_hit(question=user_message)
        if isinstance(cache_result, str):
            # A semantically similar question has already been answered.
            print("Cache hit!")
            print(f"Answer: {cache_result}")
        else:
            # Cache miss: keep the query embedding so outlet() can store the new answer.
            self.cache_embedding, self.cache_flag = cache_result

        guidedRoute = router.guide(user_message)[1]
        if guidedRoute == CHITCHAT_ROUTE_NAME:
            return body
        elif guidedRoute == MTA_ROUTE_NAME:
            ##### Database router: pick the Qdrant collection for this question #####
            db_router = DB_Router()
            result = db_router.route_query(user_message).map_db()
            if result == "stsv":
                print("Routing to so tay sinh vien")
                retriever = self.stsv_db
            elif result == "gthv":
                print("Routing to gioi thieu hoc vien")
                retriever = self.gthv_db
            elif result == "ttts":
                print("Routing to thong tin tuyen sinh")
                retriever = self.ttts_db
            else:
                print("No routing, no RAG needed")
                return body

            # Rerank the retrieved documents with Cohere before building the prompt.
            compression = ContextualCompressionRetriever(
                base_compressor=self.reranker, base_retriever=retriever
            )
            rag_chain = (
                {"context": compression | format_docs, "question": RunnablePassthrough()}
                | accurate_rag_prompt
            )
            rag_prompt = rag_chain.invoke(user_message).text

            # Inject the retrieved context as (or into) the system message.
            system_message = self.split_context(rag_prompt)
            body["messages"] = add_or_update_system_message(
                system_message, messages
            )
        return body
    async def outlet(self, body: dict, user: Optional[dict] = None) -> dict:
        print(f"outlet:{__name__}")
        print(f"Outlet Body Input: {body}")
        messages = body.get("messages", [])
        latest_question = get_last_user_message(messages)
        latest_answer = get_last_assistant_message(messages)

        if self.cache_flag:
            # inlet() recorded a cache miss: store the freshly generated answer under
            # the embedding of the question so future lookups can hit.
            cached_answer = self.cache.cache_miss(latest_question, self.cache_embedding, latest_answer)
            self.cache_embedding = None
            print("Cache miss")
            print(f"New answer added to cache: {cached_answer}")
        return body
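
if __name__ == "__main__":
    # Minimal local smoke test (a sketch, not part of the pipeline server flow):
    # the pipelines framework normally calls on_startup()/inlet()/outlet() itself.
    # Requires URL_QDRANT, API_QDRANT and COHERE_API_KEY to be set; the question
    # below is only an illustrative placeholder ("which programs does the academy offer?").
    import asyncio

    async def _demo():
        pipeline = Pipeline()
        await pipeline.on_startup()
        body = {"messages": [{"role": "user", "content": "Học viện có những ngành đào tạo nào?"}]}
        body = await pipeline.inlet(body)
        print(body)
        await pipeline.on_shutdown()

    asyncio.run(_demo())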