|
import os |
|
import time |
|
import ast |
|
from typing import List, Optional |
|
from pydantic import BaseModel |
|
|
|
from Router.router import Evaluator |
|
from semantic_router.samples import rag_sample, chitchatSample |
|
from utils.pipelines.main import get_last_user_message, add_or_update_system_message, pop_system_message |
|
from blueprints.rag_utils import format_docs |
|
from blueprints.prompts import accurate_rag_prompt, QUERY_PROMPT, evaluator_intent, basic_template, chitchat_prompt |
|
from BM25 import BM25SRetriever |
|
from SafetyChecker import SafetyChecker |
|
from langchain.retrievers import EnsembleRetriever |
|
from BM25 import BM25SRetriever |
|
from semantic_cache.main import SemanticCache |
|
from sentence_transformers import SentenceTransformer |
|
|
|
from langchain.retrievers.multi_query import MultiQueryRetriever |
|
import cohere |
|
from langchain_core.output_parsers import BaseOutputParser |
|
from langchain_cohere import CohereRerank |
|
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever |
|
from langchain_groq import ChatGroq |
|
from langchain_core.runnables import RunnablePassthrough |
|
|
|
|
|
|
|
|
|
from dotenv import load_dotenv |
|
# Load environment configuration from a local .env file into os.environ.
load_dotenv()

# Qdrant connection settings (read here but not used in this module —
# NOTE(review): confirm whether these are consumed elsewhere or dead).
qdrant_url = os.getenv('URL_QDRANT')
qdrant_api = os.getenv('API_QDRANT')

# Fail fast at import time when the Cohere key is missing: CohereRerank
# (built in Pipeline.on_startup) reads COHERE_API_KEY from the environment.
# Previously this was a bare expression statement; the explicit check keeps
# the same KeyError on a missing key but makes the intent obvious.
if "COHERE_API_KEY" not in os.environ:
    raise KeyError("COHERE_API_KEY environment variable is not set")
|
|
|
|
|
class LineListOutputParser(BaseOutputParser[List[str]]):
    """Parse an LLM completion into a list of non-empty lines."""

    def parse(self, text: str) -> List[str]:
        """Split *text* on newlines, discarding blank entries."""
        return [line for line in text.strip().split("\n") if line]
|
|
|
|
|
class Pipeline: |
|
class Valves(BaseModel): |
|
|
|
|
|
pipelines: List[str] = [] |
|
|
|
|
|
|
|
|
|
priority: int = 0 |
|
|
|
|
|
pass |
|
|
|
def __init__(self): |
|
self.type = "filter" |
|
self.name = "Filter" |
|
self.embedding = None |
|
self.route = None |
|
self.stsv_db = None |
|
self.gthv_db = None |
|
self.ttts_db = None |
|
self.reranker = None |
|
self.valves = self.Valves(**{"pipelines": ["*"]}) |
|
pass |
|
|
|
def split_context(self, context): |
|
split_index = context.find("User question") |
|
system_prompt = context[:split_index].strip() |
|
user_question = context[split_index:].strip() |
|
user_split_index = user_question.find("<context>") |
|
f_system_prompt = str(system_prompt) +"\n" + str(user_question[user_split_index:]) |
|
return f_system_prompt |
|
|
|
async def on_startup(self): |
|
|
|
print(f"on_startup:{__name__}") |
|
from typing import List |
|
from langchain_community.vectorstores import Qdrant |
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
|
|
self.embedding = SentenceTransformer("dangvantuan/vietnamese-embedding") |
|
HF_EMBEDDING = HuggingFaceEmbeddings(model_name="dangvantuan/vietnamese-embedding") |
|
from qdrant_client import QdrantClient |
|
from langchain_community.vectorstores import Qdrant |
|
|
|
|
|
|
|
|
|
|
|
client = QdrantClient(url="http://localhost:6333") |
|
|
|
gthv = Qdrant(client, collection_name="gioithieuhocvien_db", embeddings= HF_EMBEDDING) |
|
self.gthv_db = gthv.as_retriever() |
|
|
|
stsv = Qdrant(client, collection_name="sotaysinhvien_db", embeddings= HF_EMBEDDING) |
|
self.stsv_db = stsv.as_retriever() |
|
|
|
ttts = Qdrant(client, collection_name="thongtintuyensinh_db", embeddings= HF_EMBEDDING) |
|
self.ttts_db = ttts.as_retriever() |
|
|
|
import pickle |
|
with open('data/thongtintuyensinh.pkl', 'rb') as f: |
|
self.thongtintuyensinh = pickle.load(f) |
|
with open('data/sotaysinhvien.pkl', 'rb') as f: |
|
self.sotaysinhvien = pickle.load(f) |
|
with open('data/gioithieuhocvien.pkl', 'rb') as f: |
|
self.gioithieuhocvien = pickle.load(f) |
|
self.retriever_bm25_tuyensinh = BM25SRetriever.from_documents(self.thongtintuyensinh, k= 5, activate_numba = True) |
|
self.retriever_bm25_sotay = BM25SRetriever.from_documents(self.sotaysinhvien, k= 5, activate_numba = True) |
|
self.retriever_bm25_hocvien = BM25SRetriever.from_documents(self.gioithieuhocvien, k= 5, activate_numba = True) |
|
|
|
self.cache = SemanticCache() |
|
|
|
self.reranker = CohereRerank(model = "rerank-multilingual-v3.0", top_n = 5) |
|
|
|
llm = ChatGroq(model_name="llama3-70b-8192", temperature=0.1,api_key= os.getenv('llm_api_5')) |
|
output_parser = LineListOutputParser() |
|
self.llm_chain = QUERY_PROMPT | llm | output_parser |
|
pass |
|
|
|
async def on_shutdown(self): |
|
|
|
print(f"on_shutdown:{__name__}") |
|
pass |
|
|
|
def get_last_assistant_message(self, messages: List[dict]) -> str: |
|
for message in reversed(messages): |
|
if message["role"] == "assistant": |
|
if isinstance(message["content"], list): |
|
for item in message["content"]: |
|
if item["type"] == "text": |
|
return item["text"] |
|
return message["content"] |
|
return "" |
|
def add_or_update_system_message(self,content: str, messages: List[dict]): |
|
""" |
|
Adds a new system message at the beginning of the messages list |
|
:param msg: The message to be added or appended. |
|
:param messages: The list of message dictionaries. |
|
:return: The updated list of message dictionaries. |
|
""" |
|
|
|
if messages and messages[0].get("role") == "system": |
|
messages[0]["content"] += f"{content}\n" |
|
else: |
|
|
|
messages.insert(0, {"role": "system", "content": content}) |
|
|
|
return messages |
|
|
|
def add_messages(self,content: str, messages: List[dict]): |
|
messages.insert(0, {"role": "system", "content": content}) |
|
return messages |
|
|
|
    # Class-level default: whether the most recent inlet() call was answered
    # from the semantic cache. inlet() rebinds this per instance on every call.
    cache_hit = False
|
async def inlet(self, body: dict, user: Optional[dict] = None) -> dict: |
|
messages = body.get("messages", []) |
|
print(messages) |
|
user_message = get_last_user_message(messages) |
|
print(user_message) |
|
|
|
|
|
checker = SafetyChecker() |
|
safety_result = checker.check_safety(user_message) |
|
|
|
if safety_result != 'safe' : |
|
print("Safety check :" ,safety_result) |
|
construct_msg = f"Dựa vào thông tin trả lời câu hỏi của người dùng bằng Tiếng Việt\n\n : {safety_result}" |
|
body["messages"] = self.add_messages( |
|
construct_msg, messages) |
|
|
|
print(body) |
|
return body |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cache_result = self.cache.checker(user_message) |
|
if cache_result is not None: |
|
print("###Cache hit!###") |
|
self.cache_hit = True |
|
construct_msg = f"Dựa vào thông tin trả lời câu hỏi của người dùng bằng Tiếng Việt \n\n : {cache_result}" |
|
body["messages"] = self.add_or_update_system_message( |
|
construct_msg, messages) |
|
print(body) |
|
return body |
|
self.cache_hit = False |
|
|
|
print("No cache found! Generation continue") |
|
evaluator = Evaluator(llm="llama3-8b", prompt=evaluator_intent) |
|
output = evaluator.classify_text(user_message) |
|
|
|
retriever = None |
|
|
|
print(f'Câu hỏi người dùng: {user_message}') |
|
|
|
if output and output.result == 'OUT_OF_SCOPE' : |
|
print('OUT OF SCOPE') |
|
construct_msg = f"Dựa vào thông tin trả lời câu hỏi của người dùng bằng Tiếng Việt\n\n : {chitchat_prompt}" |
|
body["messages"] = self.add_or_update_system_message( |
|
construct_msg, messages) |
|
print(body) |
|
return body |
|
|
|
elif output and output.result == 'ASK_QUYDINH' : |
|
print('SO TAY SINH VIEN DB') |
|
retriever = self.stsv_db |
|
retriever_bm25 = self.retriever_bm25_sotay |
|
|
|
elif output and output.result == 'ASK_HOCVIEN' : |
|
print('HOC VIEN DB') |
|
retriever = self.gthv_db |
|
retriever_bm25 = self.retriever_bm25_hocvien |
|
|
|
elif output and output.result == 'ASK_TUYENSINH' : |
|
print('THONG TIN TUYEN SINH DB') |
|
retriever = self.ttts_db |
|
retriever_bm25 = self.retriever_bm25_tuyensinh |
|
|
|
if retriever is not None: |
|
retriever_multi = MultiQueryRetriever( |
|
retriever=retriever, llm_chain=self.llm_chain, parser_key="lines" |
|
) |
|
|
|
|
|
|
|
ensemble_retriever = EnsembleRetriever( |
|
retrievers=[retriever_bm25, retriever_multi], weights=[0.5, 0.5]) |
|
|
|
compression = ContextualCompressionRetriever( |
|
base_compressor=self.reranker, base_retriever=ensemble_retriever |
|
) |
|
rag_chain = ( |
|
{"context": compression | format_docs, "question": RunnablePassthrough()} |
|
| basic_template |
|
) |
|
|
|
|
|
|
|
|
|
rag_prompt = rag_chain.invoke(user_message ).text |
|
system_message = self.split_context(rag_prompt) |
|
body["messages"] = self.add_or_update_system_message( |
|
system_message, messages |
|
) |
|
print(body) |
|
|
|
|
|
return body |
|
else: |
|
print('Retriever is not defined. Check output results and ensure retriever is assigned correctly.') |
|
|
|
|
|
async def outlet(self, body : dict , user : Optional[dict]= None) -> dict : |
|
print("##########################") |
|
messages = body.get("messages", []) |
|
|
|
user_message = get_last_user_message(messages) |
|
print(user_message) |
|
print("########### Câu hỏi vừa hỏi #################") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print(f"Outlet Body Input: {body}") |
|
return body |