import os
import time
from typing import List, Optional

from pydantic import BaseModel
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
import cohere
from langchain_cohere import CohereRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_core.runnables import RunnablePassthrough

from schemas import OpenAIChatMessage
from utils.pipelines.main import (
    get_last_user_message,
    add_or_update_system_message,
    get_last_assistant_message,
)
from blueprints.rag_utils import format_docs
from blueprints.prompts import accurate_rag_prompt
from BM25 import BM25SRetriever
from semantic_router import SemanticRouter
from semantic_router.route import Route
from semantic_router.samples import rag_sample, chitchatSample
from SafetyChecker import SafetyChecker
from semantic_cache import Cache
from database_Routing import DB_Router

load_dotenv()

qdrant_url = os.getenv("URL_QDRANT")
qdrant_api = os.getenv("API_QDRANT")
# Fail fast: raise a KeyError at import time if the Cohere key is absent,
# since CohereRerank reads COHERE_API_KEY from the environment later on.
os.environ["COHERE_API_KEY"]
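
# Filter flow (as implemented below): inlet() semantically routes each user
# question (chitchat vs. RAG), picks the matching Qdrant collection through
# DB_Router, retrieves and Cohere-reranks context, and injects it into the
# system message; outlet() then writes fresh answers back to the semantic
# cache. The .env file is expected to define URL_QDRANT and API_QDRANT.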
class Pipeline:
    """Open WebUI filter pipeline for Vietnamese RAG over Qdrant collections,
    with semantic routing, Cohere reranking, and a semantic answer cache."""

    class Valves(BaseModel):
        # Target pipeline ids this filter attaches to; ["*"] means all pipelines.
        pipelines: List[str] = []
        # Filters with lower priority numbers run first.
        priority: int = 0

    def __init__(self):
        self.type = "filter"
        self.name = "Filter"
        self.embedding = None
        self.route = None
        # Per-turn cache state shared between inlet() and outlet().
        self.cache_flag = None
        self.cache_embedding = None
        self.cache_answer = None
        self.cache = Cache(embedding="keepitreal/vietnamese-sbert")
        # Retrievers are built in on_startup(), one per Qdrant collection.
        self.stsv_db = None
        self.gthv_db = None
        self.ttts_db = None
        self.reranker = None
        self.valves = self.Valves(**{"pipelines": ["*"]})
    def split_context(self, context):
        """Turn a rendered RAG prompt into a system message: keep everything
        before the "User question" marker plus the <context> block, dropping
        the question text itself (it already arrives as a user turn)."""
        split_index = context.find("User question")
        system_prompt = context[:split_index].strip()
        user_question = context[split_index:].strip()
        context_index = user_question.find("<context>")
        return system_prompt + user_question[context_index:]
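
    # Assumed prompt layout (inferred from the markers used above):
    #   <instructions> ... User question: <question> <context> ... </context>
    # Note that str.find() returns -1 when a marker is absent, so the result
    # is only well-formed for prompts that contain both markers.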
    async def on_startup(self):
        print(f"on_startup:{__name__}")
        from langchain_community.vectorstores import Qdrant
        from langchain_huggingface import HuggingFaceEmbeddings
        from qdrant_client import QdrantClient

        # Embedding model for the semantic router; the same model, wrapped for
        # LangChain, embeds queries against the Qdrant collections.
        self.embedding = SentenceTransformer("keepitreal/vietnamese-sbert")
        hf_embedding = HuggingFaceEmbeddings(model_name="keepitreal/vietnamese-sbert")

        client = QdrantClient(qdrant_url, api_key=qdrant_api)

        # One MMR retriever per collection; DB_Router picks among them in inlet().
        gthv = Qdrant(client, collection_name="gioithieuhocvien_db", embeddings=hf_embedding)
        self.gthv_db = gthv.as_retriever(search_type="mmr", search_kwargs={"k": 20})

        stsv = Qdrant(client, collection_name="sotaysinhvien_db", embeddings=hf_embedding)
        self.stsv_db = stsv.as_retriever(search_type="mmr", search_kwargs={"k": 20})

        ttts = Qdrant(client, collection_name="thongtintuyensinh_db", embeddings=hf_embedding)
        self.ttts_db = ttts.as_retriever(search_type="mmr", search_kwargs={"k": 20})

        # Rerank the 20 retrieved chunks down to the 10 most relevant.
        self.reranker = CohereRerank(model="rerank-multilingual-v3.0", top_n=10)
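
    # on_startup() only wires clients and retrievers together; the Qdrant
    # collections (gioithieuhocvien_db, sotaysinhvien_db, thongtintuyensinh_db)
    # are assumed to have been populated by a separate ingestion step.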
    async def on_shutdown(self):
        print(f"on_shutdown:{__name__}")

    async def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
        messages = body.get("messages", [])
        user_message = get_last_user_message(messages)
        print(user_message)

        # Semantic routing: decide whether the question needs RAG ("mta") or is
        # chitchat the model can answer without retrieval.
        MTA_ROUTE_NAME = "mta"
        CHITCHAT_ROUTE_NAME = "chitchat"
        mta_route = Route(name=MTA_ROUTE_NAME, samples=rag_sample)
        chitchat_route = Route(name=CHITCHAT_ROUTE_NAME, samples=chitchatSample)
        router = SemanticRouter(self.embedding, routes=[mta_route, chitchat_route])

        # Semantic cache lookup: a string result is a cached answer; otherwise we
        # keep the query embedding and miss flag so outlet() can store the reply.
        cache_result = self.cache.cached_hit(question=user_message)
        if isinstance(cache_result, str):
            # NOTE: the cached answer is only logged here; the request still
            # flows through the normal pipeline below.
            print("Cache hit!")
            print(f"Answer: {cache_result}")
        else:
            self.cache_embedding, self.cache_flag = cache_result

        guided_route = router.guide(user_message)[1]
        if guided_route == CHITCHAT_ROUTE_NAME:
            return body

        elif guided_route == MTA_ROUTE_NAME:
            # Database routing: pick the Qdrant collection matching the topic.
            db_router = DB_Router()
            result = db_router.route_query(user_message).map_db()
            if result == "stsv":
                print("Routing to so tay sinh vien")  # student handbook
                retriever = self.stsv_db
            elif result == "gthv":
                print("Routing to gioi thieu hoc vien")  # academy introduction
                retriever = self.gthv_db
            elif result == "ttts":
                print("Routing to thong tin tuyen sinh")  # admissions information
                retriever = self.ttts_db
            else:
                print("No routing, no RAG needed")
                return body

            # Retrieve with MMR, rerank with Cohere, then render the RAG prompt.
            compression = ContextualCompressionRetriever(
                base_compressor=self.reranker, base_retriever=retriever
            )
            rag_chain = (
                {"context": compression | format_docs, "question": RunnablePassthrough()}
                | accurate_rag_prompt
            )

            rag_prompt = rag_chain.invoke(user_message).text
            system_message = self.split_context(rag_prompt)
            body["messages"] = add_or_update_system_message(system_message, messages)

        return body
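
    # Cache contract between inlet() and outlet(): on a miss, inlet() stores the
    # query embedding and raises cache_flag; once the model has replied, outlet()
    # persists the assistant's answer under that embedding.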
    async def outlet(self, body: dict, user: Optional[dict] = None) -> dict:
        print(f"outlet:{__name__}")
        print(f"Outlet Body Input: {body}")
        messages = body.get("messages", [])
        latest_question = get_last_user_message(messages)
        latest_answer = get_last_assistant_message(messages)
        if self.cache_flag:
            cached_answer = self.cache.cache_miss(latest_question, self.cache_embedding, latest_answer)
            # Reset per-turn state so a stale embedding or flag never leaks
            # into the next request.
            self.cache_embedding = None
            self.cache_flag = None
            print("Cache miss")
            print(f"New answer added to cache: {cached_answer}")

        return body
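
# Minimal local smoke test: a sketch only, assuming the .env keys, the Qdrant
# collections, and the local modules imported above are all available. Open
# WebUI normally drives on_startup/inlet/outlet itself; this just mimics one
# turn with a hypothetical Vietnamese admissions question.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        pipeline = Pipeline()
        await pipeline.on_startup()
        body = {"messages": [{"role": "user", "content": "Học viện có những ngành đào tạo nào?"}]}
        body = await pipeline.inlet(body)
        print(body["messages"])  # system message now carries the injected RAG context
        await pipeline.on_shutdown()

    asyncio.run(_demo())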