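"""Open WebUI filter pipeline for a Vietnamese university RAG chatbot.

inlet() routes each question either to chitchat (passed through unchanged) or
to one of three Qdrant collections (student handbook, academy introduction,
admission information), retrieves and reranks context with Cohere, and injects
it as a system message; outlet() stores new answers in a semantic cache.
"""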
import os
from typing import List, Optional
from pydantic import BaseModel
from semantic_router.route import Route
from semantic_router.samples import rag_sample, chitchatSample
from utils.pipelines.main import (
    get_last_user_message,
    add_or_update_system_message,
    get_last_assistant_message,
)
from blueprints.rag_utils import format_docs
from blueprints.prompts import accurate_rag_prompt
from semantic_router import SemanticRouter
from semantic_cache import Cache
from sentence_transformers import SentenceTransformer
from database_Routing import DB_Router
from langchain_cohere import CohereRerank
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_core.runnables import RunnablePassthrough
from dotenv import load_dotenv
load_dotenv()
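# Expected environment (example values; the variable names are the ones read
# below and, implicitly, by CohereRerank):
#   URL_QDRANT=https://<your-cluster>.qdrant.io
#   API_QDRANT=<qdrant-api-key>
#   COHERE_API_KEY=<cohere-api-key>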
qdrant_url = os.getenv('URL_QDRANT')
qdrant_api = os.getenv('API_QDRANT')
# Fail fast if the Cohere key is missing; CohereRerank reads it from the environment.
if "COHERE_API_KEY" not in os.environ:
    raise RuntimeError("COHERE_API_KEY is not set")



class Pipeline:
    class Valves(BaseModel):
        # List target pipeline ids (models) that this filter will be connected to.
        # If you want to connect this filter to all pipelines, you can set pipelines to ["*"]
        pipelines: List[str] = []

        # Assign a priority level to the filter pipeline.
        # The priority level determines the order in which the filter pipelines are executed.
        # The lower the number, the higher the priority.
        priority: int = 0

        # Add your custom parameters/configuration here, e.g. an API_KEY that you want the user to configure.

    def __init__(self):
        self.type = "filter"
        self.name = "Filter"
        self.embedding = None
        self.route = None
        # Per-request semantic-cache state, set in inlet() and consumed in outlet().
        self.cache_flag = None
        self.cache_embedding = None
        self.cache_answer = None
        self.cache = Cache(embedding='keepitreal/vietnamese-sbert')
        # Retrievers for the three Qdrant collections, created in on_startup().
        self.stsv_db = None
        self.gthv_db = None
        self.ttts_db = None
        self.reranker = None
        self.valves = self.Valves(**{"pipelines": ["*"]})

    def split_context(self, context):
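        """Rebuild the system prompt from a rendered RAG prompt.

        Assumes the template renders roughly as:
            <instructions>
            User question: <question>
            <context> ... </context>
        and returns the instructions plus the <context> block, dropping the
        question text (it is already present as the user message).
        """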
        split_index = context.find("User question")
        system_prompt = context[:split_index].strip()
        remainder = context[split_index:].strip()
        # Keep everything from the <context> tag onward; the question text
        # between the two markers is dropped.
        context_index = remainder.find("<context>")
        return system_prompt + remainder[context_index:]
    
    async def on_startup(self):
        # This function is called when the server is started.
        print(f"on_startup:{__name__}")
        from langchain_community.vectorstores import Qdrant
        from langchain_huggingface import HuggingFaceEmbeddings
        from qdrant_client import QdrantClient

        # Embedding model used by the semantic router; the same model backs the
        # Qdrant retrievers via HuggingFaceEmbeddings.
        self.embedding = SentenceTransformer('keepitreal/vietnamese-sbert')
        HF_EMBEDDING = HuggingFaceEmbeddings(model_name="keepitreal/vietnamese-sbert")

        client = QdrantClient(qdrant_url, api_key=qdrant_api)

        # One MMR retriever per collection; DB_Router picks between them in inlet().
        gthv = Qdrant(client, collection_name="gioithieuhocvien_db", embeddings=HF_EMBEDDING)
        self.gthv_db = gthv.as_retriever(search_type="mmr", search_kwargs={"k": 20})

        stsv = Qdrant(client, collection_name="sotaysinhvien_db", embeddings=HF_EMBEDDING)
        self.stsv_db = stsv.as_retriever(search_type="mmr", search_kwargs={"k": 20})

        ttts = Qdrant(client, collection_name="thongtintuyensinh_db", embeddings=HF_EMBEDDING)
        self.ttts_db = ttts.as_retriever(search_type="mmr", search_kwargs={"k": 20})

        self.reranker = CohereRerank(model="rerank-multilingual-v3.0", top_n=10)

    async def on_shutdown(self):
        # This function is called when the server is stopped.
        print(f"on_shutdown:{__name__}")
        pass
    

    async def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
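        """Pre-process the request before it reaches the model.

        Flow: semantic-cache lookup -> semantic route (chitchat vs. RAG) ->
        database route (which Qdrant collection) -> retrieve + rerank ->
        inject the retrieved context as a system message.
        """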
        messages = body.get("messages", [])
        user_message = get_last_user_message(messages)
        print(user_message)
        ##### Semantic routing: chitchat vs. university-domain (RAG) questions #####
        MTA_ROUTE_NAME = 'mta'
        CHITCHAT_ROUTE_NAME = 'chitchat'
        mtaRoute = Route(name=MTA_ROUTE_NAME, samples=rag_sample)
        chitchatRoute = Route(name=CHITCHAT_ROUTE_NAME, samples=chitchatSample)
        router = SemanticRouter(self.embedding, routes=[mtaRoute, chitchatRoute])
        
        
        # Semantic cache: a string result means a hit; otherwise it returns the
        # query embedding plus a flag telling outlet() to store the new answer.
        cache_result = self.cache.cached_hit(question=user_message)
        if isinstance(cache_result, str):
            # NOTE: the cached answer is only logged here; the request still goes through RAG.
            print("Cache hit!")
            print(f"Answer: {cache_result}")
            self.cache_flag = None  # nothing for outlet() to store
        else:
            self.cache_embedding, self.cache_flag = cache_result
            
        guidedRoute = router.guide(user_message)[1]
        if guidedRoute == CHITCHAT_ROUTE_NAME:
            return body

        if guidedRoute == MTA_ROUTE_NAME:
            db_router = DB_Router()
            result = db_router.route_query(user_message).map_db()
            if result == "stsv":
                print("Routing to so tay sinh vien")
                retriever = self.stsv_db
            elif result == "gthv":
                print("Routing to gioi thieu hoc vien")
                retriever = self.gthv_db
            elif result == "ttts":
                print("Routing to thong tin tuyen sinh")
                retriever = self.ttts_db
            else:
                print("No routing, no RAG needed")
                return body

            # Rerank the top-20 MMR hits down to 10 before building the prompt.
            compression = ContextualCompressionRetriever(
                base_compressor=self.reranker, base_retriever=retriever
            )
            rag_chain = (
                {"context": compression | format_docs, "question": RunnablePassthrough()}
                | accurate_rag_prompt
            )
            rag_prompt = rag_chain.invoke(user_message).text
            system_message = self.split_context(rag_prompt)
            body["messages"] = add_or_update_system_message(system_message, messages)

        return body
    
    async def outlet(self, body: dict, user: Optional[dict] = None) -> dict:
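        """Post-process the response; store new answers in the semantic cache."""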
        print(f"outlet:{__name__}")
        print(f"Outlet Body Input: {body}")
        messages = body.get("messages", [])
        lastest_question = get_last_user_message(messages)
        lastest_answer = get_last_assistant_message(messages)
        if self.cache_flag:
            cached_answer = self.cache.cache_miss(lastest_question, self.cache_embedding, lastest_answer)
            self.cache_embedding = None
            print(f"Cache miss")
            print(f"New answer added to cache: {cached_answer}")
        
        return body