import os
from typing import List, Optional

from pydantic import BaseModel

from utils.pipelines.main import (
    get_last_user_message,
    add_or_update_system_message,
    get_last_assistant_message,
)
from blueprints.rag_utils import format_docs
from blueprints.prompts import accurate_rag_prompt
from semantic_router import SemanticRouter
from semantic_router.route import Route
from semantic_router.samples import rag_sample, chitchatSample
from semantic_cache import Cache
from sentence_transformers import SentenceTransformer
from database_Routing import DB_Router
from langchain_cohere import CohereRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_core.runnables import RunnablePassthrough
from dotenv import load_dotenv

load_dotenv()

qdrant_url = os.getenv("URL_QDRANT")
qdrant_api = os.getenv("API_QDRANT")
# Fail fast if the Cohere key is missing; CohereRerank reads it from the environment.
if "COHERE_API_KEY" not in os.environ:
    raise EnvironmentError("COHERE_API_KEY is not set")


class Pipeline:
    class Valves(BaseModel):
        # List target pipeline ids (models) that this filter will be connected to.
        # To connect this filter to all pipelines, set pipelines to ["*"].
        pipelines: List[str] = []

        # Priority level of the filter pipeline. The priority level determines
        # the order in which filter pipelines are executed: the lower the
        # number, the higher the priority.
        priority: int = 0

        # Add custom parameters/configuration here, e.g. an API_KEY that you
        # want the user to configure.

    def __init__(self):
        self.type = "filter"
        self.name = "Filter"
        self.embedding = None
        self.route = None
        self.cache_flag = None
        self.cache_embedding = None
        self.cache_answer = None
        self.cache = Cache(embedding="keepitreal/vietnamese-sbert")
        self.stsv_db = None
        self.gthv_db = None
        self.ttts_db = None
        self.reranker = None
        self.valves = self.Valves(**{"pipelines": ["*"]})

    def split_context(self, context):
        # Split the rendered RAG prompt at the "User question" marker, then
        # rejoin the stripped halves into a single system-prompt string.
        split_index = context.find("User question")
        system_prompt = context[:split_index].strip()
        user_question = context[split_index:].strip()
        # NOTE: find("") always returns 0, so the whole question block is appended.
        user_split_index = user_question.find("")
        return str(system_prompt) + str(user_question[user_split_index:])
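    # Hedged example of what split_context does. The exact prompt wording comes
    # from accurate_rag_prompt, so the string below is an illustrative
    # assumption, not the real template:
    #
    #   ctx = "Answer from the context below. ... User question: Hoc phi mot ky la bao nhieu?"
    #   self.split_context(ctx)
    #   -> "Answer from the context below. ...User question: Hoc phi mot ky la bao nhieu?"
    #
    # Because find("") returns 0, the method effectively re-concatenates the
    # stripped halves of the rendered prompt into one system message.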
print(f"on_startup:{__name__}") from typing import List import bs4 from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.document_loaders import WebBaseLoader from langchain_community.vectorstores import Qdrant from langchain_core.runnables import RunnablePassthrough from langchain_qdrant import QdrantVectorStore self.embedding = SentenceTransformer('keepitreal/vietnamese-sbert') from langchain_huggingface import HuggingFaceEmbeddings HF_EMBEDDING = HuggingFaceEmbeddings(model_name="keepitreal/vietnamese-sbert") from qdrant_client import QdrantClient from langchain_community.vectorstores import Qdrant client = QdrantClient( qdrant_url, api_key=qdrant_api ) gthv = Qdrant(client, collection_name="gioithieuhocvien_db", embeddings= HF_EMBEDDING) self.gthv_db = gthv.as_retriever(search_type="mmr", search_kwargs={"k": 20}) stsv = Qdrant(client, collection_name="sotaysinhvien_db", embeddings= HF_EMBEDDING) self.stsv_db = stsv.as_retriever(search_type="mmr", search_kwargs={"k": 20}) ttts = Qdrant(client, collection_name="thongtintuyensinh_db", embeddings= HF_EMBEDDING) self.ttts_db = ttts.as_retriever(search_type="mmr", search_kwargs={"k": 20}) self.reranker = CohereRerank(model = "rerank-multilingual-v3.0", top_n = 10) pass async def on_shutdown(self): # This function is called when the server is stopped. print(f"on_shutdown:{__name__}") pass async def inlet(self, body: dict, user: Optional[dict] = None) -> dict: messages = body.get("messages", []) user_message = get_last_user_message(messages) print(user_message) #####Router##### MTA_ROUTE_NAME = 'mta' CHITCHAT_ROUTE_NAME = 'chitchat' mtaRoute = Route(name=MTA_ROUTE_NAME, samples=rag_sample) chitchatRoute = Route(name=CHITCHAT_ROUTE_NAME, samples=chitchatSample) router = SemanticRouter(self.embedding, routes=[mtaRoute, chitchatRoute]) cache_result = self.cache.cached_hit(question= user_message) if isinstance(cache_result, str): print(f"Cache hit!") print(f"Answer: {cache_result}") else: self.cache_embedding, self.cache_flag = cache_result guidedRoute = router.guide(user_message)[1] if guidedRoute == CHITCHAT_ROUTE_NAME : return body elif guidedRoute == MTA_ROUTE_NAME : router = DB_Router() result = router.route_query(user_message).map_db() if result == "stsv" : print("Routing to so tay sinh vien") retriever = self.stsv_db elif result == "gthv" : print("Routing to gioi thieu hoc vien") retriever = self.gthv_db elif result == "ttts" : print("Routing to thong tin tuyen sinh") retriever = self.ttts_db else : print("No routing, no RAG need") return body compression = ContextualCompressionRetriever( base_compressor=self.reranker, base_retriever=retriever ) rag_chain = ( {"context": compression | format_docs, "question": RunnablePassthrough()} | accurate_rag_prompt ) rag_prompt = rag_chain.invoke(user_message).text system_message = self.split_context(rag_prompt) body["messages"] = add_or_update_system_message( system_message, messages ) return body async def outlet(self, body : dict , user : Optional[dict]= None) -> dict : print(f"outlet:{__name__}") print(f"Outlet Body Input: {body}") messages = body.get("messages", []) lastest_question = get_last_user_message(messages) lastest_answer = get_last_assistant_message(messages) if self.cache_flag: cached_answer = self.cache.cache_miss(lastest_question, self.cache_embedding, lastest_answer) self.cache_embedding = None print(f"Cache miss") print(f"New answer added to cache: {cached_answer}") return body