|
from prompts import get_classification_prompt, get_query_generation_prompt

from utils_code import initialize_openai_creds, create_llm

from llama_index.core.schema import QueryBundle, NodeWithScore

from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever, KGTableRetriever

from transformers import pipeline

from typing import List, Optional, Tuple

import asyncio

from llama_index.core.postprocessor import SentenceTransformerRerank

from llama_index.core.indices.property_graph import LLMSynonymRetriever, VectorContextRetriever
|
|
|
|
|
class PARetriever(BaseRetriever): |
|
"""Custom retriever that performs query rewriting, Vector search, and BM25 search without Knowledge Graph search.""" |
|
|
|
def __init__( |
|
self, |
|
llm, |
|
vector_retriever: Optional[VectorIndexRetriever] = None, |
|
bm25_retriever: Optional[BaseRetriever] = None, |
|
mode: str = "OR", |
|
rewriter: bool = True, |
|
classifier_model: Optional[str] = None, |
|
device: str = 'cpu', |
|
reranker_model_name: Optional[str] = None, |
|
verbose: bool = False, |
|
fixed_params: Optional[dict] = None, |
|
categories_list: Optional[List[str]] = None, |
|
param_mappings: Optional[dict] = None |
|
) -> None: |
|
"""Initialize PARetriever parameters.""" |
|
self._vector_retriever = vector_retriever |
|
self._bm25_retriever = bm25_retriever |
|
self._llm = llm |
|
self._rewriter = rewriter |
|
self._mode = mode |
|
self._reranker_model_name = reranker_model_name |
|
self._reranker = None |
|
self.verbose = verbose |
|
self.fixed_params = fixed_params |
|
self.categories_list = categories_list |
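
        # label_0/1/2 are the labels emitted by the optional text-classification
        # model; higher labels (harder queries) map to deeper retrieval settings.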
|
self.param_mappings = param_mappings or { |
|
"label_0": {"top_k": 5, "max_keywords_per_query": 3, "max_knowledge_sequence": 1}, |
|
"label_1": {"top_k": 7, "max_keywords_per_query": 4, "max_knowledge_sequence": 2}, |
|
"label_2": {"top_k": 10, "max_keywords_per_query": 5, "max_knowledge_sequence": 3} |
|
} |
|
|
|
|
|
self.classifier = None |
|
if classifier_model: |
|
self.classifier = pipeline("text-classification", model=classifier_model, device=device) |
|
|
|
if mode not in ("AND", "OR"): |
|
raise ValueError("Invalid mode.") |
|
|
|
    def classify_query_and_get_params(self, query: str) -> Tuple[Optional[str], dict]:
|
"""Classify the query and determine adaptive parameters or use fixed parameters.""" |
|
        if self.fixed_params:
            # Fixed parameters take precedence over the classifier.
            params = self.fixed_params
            self._classification_result = "Fixed"
            if self.verbose:
                print(f"Using fixed parameters: {params}")
            return "Fixed", params

        # Defaults, used when no classifier is configured or the predicted
        # label has no mapping.
        params = {
            "top_k": 5,
            "max_keywords_per_query": 4,
            "max_knowledge_sequence": 2
        }
        classification_result = None

        if self.classifier:
            classification = self.classifier(query)[0]
            label = classification['label']
            classification_result = label
            if self.verbose:
                print(f"Query Classification: {label} with score {classification['score']}")

            if label in self.param_mappings:
                params = self.param_mappings[label]
            elif self.verbose:
                print(f"Warning: No mapping found for label {label}, using default parameters.")

        self._classification_result = classification_result
        return classification_result, params
|
|
|
def classify_query(self, query_str: str) -> Optional[str]: |
|
"""Classify the query into one of the predefined categories using LLM, or skip if no categories are provided.""" |
|
if not self.categories_list: |
|
if self.verbose: |
|
print("No categories provided, skipping query classification.") |
|
return None |
|
|
|
|
|
classification_prompt = get_classification_prompt(self.categories_list) + f" Query: '{query_str}'" |
|
|
|
response = self._llm.complete(classification_prompt) |
|
category = response.text.strip() |
|
|
|
|
|
return category if category in self.categories_list else None |
|
|
|
def generate_queries(self, query_str: str, category: Optional[str], num_queries: int = 3) -> List[str]: |
|
"""Generate query variations using the LLM, taking into account the category if applicable.""" |
|
|
|
|
|
query_gen_prompt = get_query_generation_prompt(query_str, num_queries) |
|
|
|
response = self._llm.complete(query_gen_prompt) |
|
queries = response.text.split("\n") |
|
|
|
queries = [query.strip() for query in queries if query.strip()] |
|
|
|
        if category:
            # Append the category itself as an extra retrieval query.
            queries.append(category)
|
|
|
return queries |
|
|
|
    async def run_queries(self, queries: List[str], retrievers: List[BaseRetriever]) -> dict:
        """Run every query against every retriever concurrently."""
        tasks = []
        task_keys = []
        for query in queries:
            for retriever in retrievers:
                tasks.append(retriever.aretrieve(query))
                # Key by (query, task index) so no result is dropped, even when
                # several retrievers run the same query.
                task_keys.append((query, len(task_keys)))

        task_results = await asyncio.gather(*tasks)

        # Build the dict from keys collected alongside the tasks; zipping over
        # `queries` alone would silently discard all but the first
        # len(queries) results whenever more than one retriever is active.
        return dict(zip(task_keys, task_results))
|
|
|
def fuse_vector_and_bm25_results(self, results_dict, similarity_top_k: int) -> List[NodeWithScore]: |
|
"""Fuse results from Vector and BM25 retrievers.""" |
|
k = 60.0 |
|
fused_scores = {} |
|
text_to_node = {} |
|
|
|
for nodes_with_scores in results_dict.values(): |
|
for rank, node_with_score in enumerate( |
|
sorted(nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True) |
|
): |
|
text = node_with_score.node.get_content() |
|
text_to_node[text] = node_with_score |
|
if text not in fused_scores: |
|
fused_scores[text] = 0.0 |
|
fused_scores[text] += 1.0 / (rank + k) |
|
|
|
reranked_results = dict(sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)) |
|
|
|
reranked_nodes: List[NodeWithScore] = [] |
|
for text, score in reranked_results.items(): |
|
if text in text_to_node: |
|
node = text_to_node[text] |
|
node.score = score |
|
reranked_nodes.append(node) |
|
else: |
|
if self.verbose: |
|
print(f"Warning: Text not found in `text_to_node`: {text}") |
|
|
|
return reranked_nodes[:similarity_top_k] |
|
|
|
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: |
|
"""Retrieve nodes given query.""" |
|
        category = None
        if self._rewriter:
            category = self.classify_query(query_bundle.query_str)
            if self.verbose and category:
                print(f"Classified Category: {category}")
|
|
|
classification_result, params = self.classify_query_and_get_params(query_bundle.query_str) |
|
self._classification_result = classification_result |
|
|
|
top_k = params["top_k"] |
|
|
|
if self._reranker_model_name: |
|
self._reranker = SentenceTransformerRerank(model=self._reranker_model_name, top_n=top_k) |
|
if self.verbose: |
|
print(f"Initialized reranker with top_n: {top_k}") |
|
|
|
        # Scale the number of rewrites with retrieval depth: 3 rewrites for
        # top_k=5, 5 for top_k=7, otherwise 7.
        num_queries = 3 if top_k == 5 else 5 if top_k == 7 else 7
|
if self.verbose: |
|
print(f"Number of Query Rewrites: {num_queries}") |
|
|
|
if self._rewriter: |
|
queries = self.generate_queries(query_bundle.query_str, category, num_queries=num_queries) |
|
if self.verbose: |
|
print(f"Generated Queries: {queries}") |
|
else: |
|
queries = [query_bundle.query_str] |
|
|
|
active_retrievers = [] |
|
if self._vector_retriever: |
|
active_retrievers.append(self._vector_retriever) |
|
if self._bm25_retriever: |
|
active_retrievers.append(self._bm25_retriever) |
|
|
|
        if not active_retrievers:
            raise ValueError("No active retriever provided! Pass a vector_retriever and/or a bm25_retriever.")

        # Note: asyncio.run assumes no event loop is already running in this thread.
        results = asyncio.run(self.run_queries(queries, active_retrievers))
        if self.verbose:
            print(f"Fusion Results: {results}")
|
|
|
final_results = self.fuse_vector_and_bm25_results(results, similarity_top_k=top_k) |
|
|
|
if self._reranker: |
|
final_results = self._reranker.postprocess_nodes(final_results, query_bundle) |
|
if self.verbose: |
|
print(f"Reranked Results: {final_results}") |
|
else: |
|
final_results = final_results[:top_k] |
|
|
|
if self._rewriter: |
|
unique_nodes = {} |
|
for node in final_results: |
|
content = node.node.get_content() |
|
if content not in unique_nodes: |
|
unique_nodes[content] = node |
|
final_results = list(unique_nodes.values()) |
|
|
|
if self.verbose: |
|
print(f"Final Results: {final_results}") |
|
|
|
return final_results |
|
|
|
    def get_classification_result(self) -> Optional[str]:
        """Return the label from the most recent query classification, or None."""
        return getattr(self, "_classification_result", None)
|
|
|
|
|
class HyPARetriever(PARetriever): |
|
"""Custom retriever that extends PARetriever with knowledge graph (KG) search.""" |
|
|
|
def __init__( |
|
self, |
|
llm, |
|
vector_retriever: Optional[VectorIndexRetriever] = None, |
|
bm25_retriever: Optional[BaseRetriever] = None, |
|
kg_index=None, |
|
property_index: bool = True, |
|
pg_filters=None, |
|
**kwargs, |
|
): |
|
|
|
super().__init__( |
|
llm=llm, |
|
vector_retriever=vector_retriever, |
|
bm25_retriever=bm25_retriever, |
|
**kwargs |
|
) |
|
|
|
|
|
self._pg_filters = pg_filters |
|
self._kg_index = kg_index |
|
self.property_index = property_index |
|
|
|
def _initialize_kg_retriever(self, params): |
|
"""Initialize the KG retriever based on retrieval mode.""" |
|
graph_index = self._kg_index |
|
filters = self._pg_filters |
|
|
|
if self._kg_index and not self.property_index: |
|
|
|
return KGTableRetriever( |
|
index=self._kg_index, |
|
retriever_mode='hybrid', |
|
max_keywords_per_query=params["max_keywords_per_query"], |
|
max_knowledge_sequence=params["max_knowledge_sequence"] |
|
) |
|
|
|
elif self._kg_index and self.property_index: |
|
|
|
|
|
|
|
vector_retriever = VectorContextRetriever( |
|
graph_store=graph_index.property_graph_store, |
|
similarity_top_k=params["max_keywords_per_query"], |
|
path_depth=params["max_knowledge_sequence"], |
|
include_text=True, |
|
filters=filters |
|
) |
|
synonym_retriever = LLMSynonymRetriever( |
|
graph_store=graph_index.property_graph_store, |
|
llm=self._llm, |
|
include_text=True, |
|
filters=filters |
|
) |
|
return graph_index.as_retriever(sub_retrievers=[vector_retriever, synonym_retriever]) |
|
|
|
|
|
return None |
|
|
|
def _combine_with_kg_results(self, vector_bm25_results, kg_results): |
|
"""Combine KG results with vector and BM25 results.""" |
|
vector_ids = {n.node.id_ for n in vector_bm25_results} |
|
kg_ids = {n.node.id_ for n in kg_results} |
|
|
|
combined_results = {n.node.id_: n for n in vector_bm25_results} |
|
combined_results.update({n.node.id_: n for n in kg_results}) |
|
|
|
if self._mode == "AND": |
|
result_ids = vector_ids.intersection(kg_ids) |
|
else: |
|
result_ids = vector_ids.union(kg_ids) |
|
|
|
        # Sort for a deterministic, relevance-first ordering (set iteration
        # order is arbitrary).
        return sorted((combined_results[rid] for rid in result_ids), key=lambda n: n.score or 0.0, reverse=True)
|
|
|
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: |
|
"""Retrieve nodes with KG integration.""" |
|
|
|
final_results = super()._retrieve(query_bundle) |
|
|
|
|
|
        if self._kg_index:
            _, params = self.classify_query_and_get_params(query_bundle.query_str)
            kg_retriever = self._initialize_kg_retriever(params)

            if kg_retriever:
                kg_nodes = kg_retriever.retrieve(query_bundle)
                # Merge KG hits into the fused vector/BM25 results for both KG
                # modes, so KGTableRetriever results are not discarded.
                final_results = self._combine_with_kg_results(final_results, kg_nodes)
|
|
|
return final_results |
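

# ---------------------------------------------------------------------------
# Example usage script. This portion is assumed to live in a separate file
# (e.g. main.py), since it imports the classes above from a `retrievers`
# module.
# ---------------------------------------------------------------------------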
|
|
|
|
|
|
|
from dotenv import load_dotenv

from utils_code import initialize_openai_creds, create_llm

from llama_index.core import VectorStoreIndex, Settings, KnowledgeGraphIndex

from llama_index.embeddings.huggingface import HuggingFaceEmbedding

from llama_index.core.node_parser import SentenceSplitter

from llama_index.core.retrievers import VectorIndexRetriever

from llama_index.retrievers.bm25 import BM25Retriever

from llama_index.readers.file import PyMuPDFReader

from llama_index.core.chat_engine import ContextChatEngine

from llama_index.core.memory.chat_memory_buffer import ChatMemoryBuffer

from retrievers import PARetriever, HyPARetriever
|
|
|
|
|
def load_documents(): |
|
"""Load and return documents from specified file paths.""" |
|
loader = PyMuPDFReader() |
|
documents1 = loader.load(file_path="../../legal_data/LL144/LL144.pdf") |
|
documents2 = loader.load(file_path="../../legal_data/LL144/LL144_Definitions.pdf") |
|
return documents1 + documents2 |
|
|
|
def create_indices(documents, llm, embed_model):
    """Create and return a VectorStoreIndex and an optional KnowledgeGraphIndex."""
    splitter = SentenceSplitter(chunk_size=512)

    vector_index = VectorStoreIndex.from_documents(
        documents,
        embed_model=embed_model,
        transformations=[splitter]
    )

    # KG construction is currently disabled; uncomment to build a graph index
    # for HyPARetriever.
    # graph_index = KnowledgeGraphIndex.from_documents(
    #     documents,
    #     max_triplets_per_chunk=5,
    #     llm=llm,
    #     embed_model=embed_model,
    #     include_embeddings=True,
    #     transformations=[splitter]
    # )
    graph_index = None

    return vector_index, graph_index
|
|
|
def create_retrievers(vector_index, graph_index, llm, category_list): |
|
"""Create and return the PA and HyPA retrievers.""" |
|
vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=10) |
|
bm25_retriever = BM25Retriever.from_defaults(index=vector_index, similarity_top_k=10) |
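
    # similarity_top_k=10 sets the per-retriever candidate depth; the final cut
    # happens inside the PA/HyPA retriever after fusion and reranking.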
|
|
|
PA_retriever = PARetriever( |
|
llm=llm, |
|
categories_list=category_list, |
|
rewriter=True, |
|
vector_retriever=vector_retriever, |
|
bm25_retriever=bm25_retriever, |
|
classifier_model="rk68/distilbert-q-classifier-3", |
|
verbose=False |
|
) |
|
|
|
    # graph_index may be None when KG construction is disabled, in which case
    # HyPARetriever skips the KG step and behaves like PARetriever.
    HyPA_retriever = HyPARetriever(
|
llm=llm, |
|
categories_list=category_list, |
|
rewriter=True, |
|
kg_index=graph_index, |
|
vector_retriever=vector_retriever, |
|
bm25_retriever=bm25_retriever, |
|
classifier_model="rk68/distilbert-q-classifier-3", |
|
verbose=False, |
|
property_index=False |
|
) |
|
|
|
return PA_retriever, HyPA_retriever |
|
|
|
def create_chat_engine(retriever, memory):
    """Create and return a ContextChatEngine using the provided retriever and memory."""
    # chat_mode and memory_cls are as_chat_engine() options, not
    # ContextChatEngine.from_defaults() parameters, so only the retriever and
    # the memory instance are passed here.
    return ContextChatEngine.from_defaults(
        retriever=retriever,
        memory=memory
    )
|
|
|
def main(): |
|
|
|
    # Load environment variables (e.g. Azure OpenAI credentials) so
    # initialize_openai_creds can read them.
    load_dotenv()

    gpt35_creds, gpt4o_mini_creds, gpt4o_creds = initialize_openai_creds()
    llm_gpt35 = create_llm(gpt35_creds=gpt35_creds, gpt4o_mini_creds=gpt4o_mini_creds, gpt4o_creds=gpt4o_creds)
|
|
|
|
|
embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5") |
|
Settings.embed_model = embed_model |
|
Settings.llm = llm_gpt35 |
|
|
|
    category_list = [
        '§ 5-301 Bias Audit',
        '§ 5-302 Data Requirements',
        '§ 5-303 Published Results',
        '§ 5-304 Notice to Candidates and Employees'
    ]
|
|
|
|
|
documents = load_documents() |
|
vector_index, graph_index = create_indices(documents, llm_gpt35, embed_model) |
|
|
|
|
|
PA_retriever, HyPA_retriever = create_retrievers(vector_index, graph_index, llm_gpt35, category_list) |
|
|
|
|
|
    # Give each engine its own memory buffer so the two responses are compared
    # on equal footing; a shared buffer would leak the first answer into the
    # second engine's context.
    PA_memory = ChatMemoryBuffer.from_defaults(token_limit=8192)
    HyPA_memory = ChatMemoryBuffer.from_defaults(token_limit=8192)

    PA_chat_engine = create_chat_engine(PA_retriever, PA_memory)
    HyPA_chat_engine = create_chat_engine(HyPA_retriever, HyPA_memory)
|
|
|
|
|
question = "What is a bias audit?" |
|
PA_response = PA_chat_engine.chat(question) |
|
HyPA_response = HyPA_chat_engine.chat(question) |
|
|
|
|
|
print("\n" + "="*50) |
|
print(f"Question: {question}") |
|
print("="*50) |
|
|
|
print("\n------- PA Retriever Response -------") |
|
print(PA_response) |
|
|
|
print("\n------- HyPA Retriever Response -------") |
|
print(HyPA_response) |
|
print("="*50 + "\n") |
|
|
|
if __name__ == '__main__': |
|
main() |
|
|