import os
import time
from typing import List, Optional
from pydantic import BaseModel
from schemas import OpenAIChatMessage
from semantic_router.route import Route
from semantic_router.samples import rag_sample, chitchatSample
from utils.pipelines.main import get_last_user_message, add_or_update_system_message, get_last_assistant_message
from blueprints.rag_utils import format_docs
from blueprints.prompts import accurate_rag_prompt
from BM25 import BM25SRetriever
from semantic_router import SemanticRouter
from SafetyChecker import SafetyChecker
from semantic_cache import Cache
from sentence_transformers import SentenceTransformer
from database_Routing import DB_Router
import cohere
from langchain_cohere import CohereRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_core.runnables import RunnablePassthrough
from dotenv import load_dotenv
load_dotenv()
qdrant_url = os.getenv('URL_QDRANT')
qdrant_api = os.getenv('API_QDRANT')
# Fail fast if the Cohere API key is missing; CohereRerank reads COHERE_API_KEY from the environment.
os.environ["COHERE_API_KEY"]

##### Filter pipeline #####
class Pipeline:
    class Valves(BaseModel):
        # List target pipeline ids (models) that this filter will be connected to.
        # If you want to connect this filter to all pipelines, set pipelines to ["*"].
        pipelines: List[str] = []

        # Assign a priority level to the filter pipeline.
        # The priority level determines the order in which the filter pipelines are executed.
        # The lower the number, the higher the priority.
        priority: int = 0

        # Add your custom parameters/configuration here, e.g. an API_KEY that you want the user to configure.
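        # Example (illustrative only, not part of the original config): expose the rerank depth as a valve.
        # rerank_top_n: int = 10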
        pass

    def __init__(self):
        self.type = "filter"
        self.name = "Filter"
        # Embedding model for the semantic router, loaded in on_startup().
        self.embedding = None
        self.route = None
        # Per-request semantic-cache state, shared between inlet() and outlet().
        self.cache_flag = None
        self.cache_embedding = None
        self.cache_answer = None
        self.cache = Cache(embedding='keepitreal/vietnamese-sbert')
        # Qdrant retrievers and the Cohere reranker, created in on_startup().
        self.stsv_db = None
        self.gthv_db = None
        self.ttts_db = None
        self.reranker = None
        self.valves = self.Valves(**{"pipelines": ["*"]})
        pass
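    # split_context(): the prompt rendered by the RAG chain contains the system instructions,
    # a "User question" section, and the retrieved <context> block. Keep only the instructions
    # and the <context> block, dropping the duplicated user question from the system message.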
    def split_context(self, context):
        split_index = context.find("User question")
        system_prompt = context[:split_index].strip()
        user_question = context[split_index:].strip()
        user_split_index = user_question.find("<context>")
        f_system_prompt = str(system_prompt) + str(user_question[user_split_index:])
        return f_system_prompt
    async def on_startup(self):
        # This function is called when the server is started.
        print(f"on_startup:{__name__}")
        from langchain_community.vectorstores import Qdrant
        from langchain_huggingface import HuggingFaceEmbeddings
        from qdrant_client import QdrantClient

        # Embedding model used by the semantic router in inlet().
        self.embedding = SentenceTransformer('keepitreal/vietnamese-sbert')
        HF_EMBEDDING = HuggingFaceEmbeddings(model_name="keepitreal/vietnamese-sbert")

        client = QdrantClient(
            qdrant_url,
            api_key=qdrant_api,
        )
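        # One MMR retriever (top-20) per Qdrant collection:
        # gioithieuhocvien_db (institute introduction), sotaysinhvien_db (student handbook),
        # thongtintuyensinh_db (admissions information).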
        gthv = Qdrant(client, collection_name="gioithieuhocvien_db", embeddings=HF_EMBEDDING)
        self.gthv_db = gthv.as_retriever(search_type="mmr", search_kwargs={"k": 20})
        stsv = Qdrant(client, collection_name="sotaysinhvien_db", embeddings=HF_EMBEDDING)
        self.stsv_db = stsv.as_retriever(search_type="mmr", search_kwargs={"k": 20})
        ttts = Qdrant(client, collection_name="thongtintuyensinh_db", embeddings=HF_EMBEDDING)
        self.ttts_db = ttts.as_retriever(search_type="mmr", search_kwargs={"k": 20})
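        # Cohere reranker: keep the 10 most relevant of the 20 retrieved chunks.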
        self.reranker = CohereRerank(model="rerank-multilingual-v3.0", top_n=10)
    async def on_shutdown(self):
        # This function is called when the server is stopped.
        print(f"on_shutdown:{__name__}")
        pass
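    # inlet() runs before the request reaches the model: route the question, retrieve and
    # rerank supporting context, and inject it into the conversation as the system message.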
    async def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
        messages = body.get("messages", [])
        user_message = get_last_user_message(messages)
        print(user_message)
        ##### Semantic router #####
        MTA_ROUTE_NAME = 'mta'
        CHITCHAT_ROUTE_NAME = 'chitchat'
        mtaRoute = Route(name=MTA_ROUTE_NAME, samples=rag_sample)
        chitchatRoute = Route(name=CHITCHAT_ROUTE_NAME, samples=chitchatSample)
        router = SemanticRouter(self.embedding, routes=[mtaRoute, chitchatRoute])
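        ##### Semantic cache #####
        # cached_hit() returns the cached answer string on a hit; otherwise it returns the
        # query embedding and a miss flag so outlet() can add the new answer to the cache.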
        cache_result = self.cache.cached_hit(question=user_message)
        if isinstance(cache_result, str):
            print("Cache hit!")
            print(f"Answer: {cache_result}")
        else:
            self.cache_embedding, self.cache_flag = cache_result
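        # Ask the semantic router whether this is small talk or an MTA question that needs RAG.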
        guidedRoute = router.guide(user_message)[1]
        if guidedRoute == CHITCHAT_ROUTE_NAME:
            # Chitchat: pass the request through unchanged.
            return body
        elif guidedRoute == MTA_ROUTE_NAME:
            # Pick the Qdrant collection that matches the question.
            db_router = DB_Router()
            result = db_router.route_query(user_message).map_db()
            if result == "stsv":
                print("Routing to so tay sinh vien")
                retriever = self.stsv_db
            elif result == "gthv":
                print("Routing to gioi thieu hoc vien")
                retriever = self.gthv_db
            elif result == "ttts":
                print("Routing to thong tin tuyen sinh")
                retriever = self.ttts_db
            else:
                print("No routing, no RAG needed")
                return body
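            # Rerank the retrieved chunks with Cohere, then render the RAG prompt over the
            # reranked context and the user question.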
            compression = ContextualCompressionRetriever(
                base_compressor=self.reranker, base_retriever=retriever
            )
            rag_chain = (
                {"context": compression | format_docs, "question": RunnablePassthrough()}
                | accurate_rag_prompt
            )
            rag_prompt = rag_chain.invoke(user_message).text
            # Keep only the instructions and the <context> block as the system message.
            system_message = self.split_context(rag_prompt)
            body["messages"] = add_or_update_system_message(system_message, messages)
        return body
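    # outlet() runs after the model has replied; on a cache miss recorded by inlet(), it
    # stores the new question/answer pair in the semantic cache.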
    async def outlet(self, body: dict, user: Optional[dict] = None) -> dict:
        print(f"outlet:{__name__}")
        print(f"Outlet Body Input: {body}")
        messages = body.get("messages", [])
        latest_question = get_last_user_message(messages)
        latest_answer = get_last_assistant_message(messages)
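        # cache_flag is set by inlet() when the question was not found in the cache.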
        if self.cache_flag:
            cached_answer = self.cache.cache_miss(latest_question, self.cache_embedding, latest_answer)
            # Reset the per-request cache state so a later cache hit does not trigger another write.
            self.cache_embedding = None
            self.cache_flag = None
            print("Cache miss")
            print(f"New answer added to cache: {cached_answer}")
        return body