|
import os |
|
from typing import Literal, Optional, Union |
|
from langchain_core.pydantic_v1 import BaseModel, Field |
|
from langchain_groq import ChatGroq |
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
|
# Closed set of database names the router may return; "no_db" means that no
# database is relevant to the question. A single Literal is equivalent to the
# Union of one-value Literals (PEP 586) and yields a flat enum in the schema.
DbName = Literal[
    "Student's book",
    "School's introduction",
    "Admission information",
    "no_db",
]

# Short code for each database name. Built once at import time instead of on
# every map_db() call; "no_db" intentionally maps to None.
_DB_CODES = {
    "Student's book": 'stsv',
    "School's introduction": 'gthv',
    "Admission information": 'ttts',
    "no_db": None,
}


# Structured output schema for the routing LLM: the model fills in db_name.
# NOTE(review): langchain derives the tool schema from this model's JSON
# schema, which includes the class docstring and the field description below,
# so treat those two strings as prompt text rather than ordinary docs.
class MultiDataSourceSearch(BaseModel):
    """Search over multi databases about a university."""

    db_name: DbName = Field(
        ...,
        description="The name of the datasource.",
    )

    def pretty_print(self) -> None:
        """Print every field whose value is not None and differs from its default."""
        for name, model_field in self.__fields__.items():
            value = getattr(self, name)
            if value is not None and value != model_field.default:
                print(f"{name}: {value}")

    def map_db(self) -> Optional[str]:
        """Maps db_name to its corresponding code.

        Prints the mapped code, or a notice when there is none.

        Returns:
            'stsv', 'gthv' or 'ttts' for the three real databases, or None
            for "no_db" (or any name missing from the mapping).
        """
        mapped_value = _DB_CODES.get(self.db_name)
        if mapped_value is not None:
            print(f"Mapped db_name: {mapped_value}")
        else:
            print("db_name is not recognized or not provided.")
        return mapped_value
|
|
|
# System prompt for the routing LLM. It must not contain "{" or "}" because
# ChatPromptTemplate treats braces as template variables (cf. "{question}").
# English gloss of the Vietnamese database descriptions:
#   Student's book        - guidelines, rules and regulations students must
#                           follow while studying at the school
#   School's introduction - a general, comprehensive introduction to the school
#   Admission information - admissions information across academic years
_ROUTER_SYSTEM_PROMPT = """
You are an expert at converting user questions into database queries.
Given a question in Vietnamese, your task is to select the most relevant database from the list below.
Think step by step:
0. If the question not in Vietnamese language, return `no_db`
1. Extract the query to find out what the topic of the query is related to.
2. Define which database is most relevant based on the extracted topic.
Match your database to the following databases:
- "Student's book": Cơ sở dữ liệu này bao gồm các hướng dẫn, quy chế, quy định mà sinh viên khi tham gia học tập tại trường sẽ phải thực hiện
- "School's introduction": Cơ sở dữ liệu này giới thiệu tổng quan và toàn diện về trường học
- "Admission information": Cơ sở dữ liệu này chứa các thông tin liên quan đến tuyển sinh qua các năm học
- "no_db": Return "no_db" if none of the databases are relevant or if the question does not match any database content.

Your response must be strictly one of the following database names and nothing else:
- "Student's book"
- "School's introduction"
- "Admission information"
- "no_db"

If you do not understand the question, or if none of the databases are relevant, return "no_db".
Do not add any extra words, explanations, or formatting. Just return the name of the database exactly as listed.
"""


class DB_Router:
    """Route a user question to the most relevant university database.

    The question is combined with a fixed system prompt and sent to a
    Groq-hosted chat model whose output is constrained to the
    ``MultiDataSourceSearch`` schema.
    """

    def __init__(
        self,
        model_name: str = "llama3-groq-70b-8192-tool-use-preview",
        api_key_env: str = "llm_api_9",
        temperature: float = 0,
    ):
        """Build the routing prompt and the structured-output LLM.

        Args:
            model_name: Groq model identifier passed to ``ChatGroq``.
            api_key_env: Name of the environment variable holding the Groq
                API key; a local ``.env`` file is loaded first.
            temperature: Sampling temperature passed to ``ChatGroq``.
        """
        from dotenv import load_dotenv

        self.system_prompt = _ROUTER_SYSTEM_PROMPT
        self.prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_prompt),
                ("human", "{question}"),
            ]
        )

        # Populate os.environ from a .env file (if present) before reading the key.
        load_dotenv()
        # NOTE(review): if the variable is unset, api_key is None and ChatGroq's
        # own default key lookup applies — confirm that is the intended fallback.
        api_key = os.getenv(api_key_env)

        self.llm = ChatGroq(
            model_name=model_name,
            temperature=temperature,
            api_key=api_key,
        )
        self.structured_llm = self.llm.with_structured_output(MultiDataSourceSearch)

    def route_query(self, question: str) -> MultiDataSourceSearch:
        """Ask the LLM which database best answers ``question``.

        Args:
            question: The user's question (the prompt expects Vietnamese).

        Returns:
            The parsed ``MultiDataSourceSearch``. If the structured-output
            chain yields nothing, a ``db_name="no_db"`` instance is returned
            so callers can always call ``map_db()`` on the result.
        """
        query_analyzer = self.prompt | self.structured_llm
        result = query_analyzer.invoke({"question": question})
        if result is None:
            # The tool-calling structured-output parser returns None when the
            # model emits no tool call; fall back to the prompt's own
            # "nothing relevant" answer instead of leaking None to callers.
            return MultiDataSourceSearch(db_name="no_db")
        return result
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: route one sample question and show the chosen
    # database and its mapped code. The question reads roughly: "I scored 450
    # on the TOEIC — how many points does that convert to at the school?"
    db_router = DB_Router()
    result = db_router.route_query(
        "tôi được 450 TOIEC thì đổi được bao nhiêu điểm trong trường ?"
    )
    print(result)
    # route_query may yield None when the model produced no structured answer;
    # guard so the demo does not crash with AttributeError.
    if result is not None:
        print(result.map_db())