from typing import List, Optional, Tuple
import asyncio

from transformers import pipeline

from prompts import get_classification_prompt, get_query_generation_prompt
from llama_index.core.schema import QueryBundle, NodeWithScore
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever, KGTableRetriever
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.indices.property_graph import LLMSynonymRetriever, VectorContextRetriever


class PARetriever(BaseRetriever):
    """Custom retriever that performs query rewriting, vector search, and BM25 search,
    without knowledge graph search."""

    def __init__(
        self,
        llm,  # LLM used for query classification and query generation
        vector_retriever: Optional[VectorIndexRetriever] = None,
        bm25_retriever: Optional[BaseRetriever] = None,
        mode: str = "OR",
        rewriter: bool = True,
        classifier_model: Optional[str] = None,  # Optional HuggingFace text-classification model
        device: str = "cpu",  # Run the classifier on CPU (suitable for the demo)
        reranker_model_name: Optional[str] = None,  # Model name for SentenceTransformerRerank
        verbose: bool = False,
        fixed_params: Optional[dict] = None,  # Fixed retrieval parameters; bypasses the classifier
        categories_list: Optional[List[str]] = None,  # Categories for LLM query classification
        param_mappings: Optional[dict] = None,  # Classifier label -> retrieval parameters
    ) -> None:
        """Initialize PARetriever parameters."""
        super().__init__()  # Set up BaseRetriever internals (e.g. the callback manager)
        self._vector_retriever = vector_retriever
        self._bm25_retriever = bm25_retriever
        self._llm = llm
        self._rewriter = rewriter
        self._mode = mode
        self._reranker_model_name = reranker_model_name
        self._reranker = None  # Built lazily in `_retrieve`, once the adaptive top_k is known
        self.verbose = verbose
        self.fixed_params = fixed_params
        self.categories_list = categories_list
        self.param_mappings = param_mappings or {
            "label_0": {"top_k": 5, "max_keywords_per_query": 3, "max_knowledge_sequence": 1},
            "label_1": {"top_k": 7, "max_keywords_per_query": 4, "max_knowledge_sequence": 2},
            "label_2": {"top_k": 10, "max_keywords_per_query": 5, "max_knowledge_sequence": 3},
        }

        # Initialize the classifier if a model was provided.
        self.classifier = None
        if classifier_model:
            self.classifier = pipeline("text-classification", model=classifier_model, device=device)

        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode; expected 'AND' or 'OR'.")

    def classify_query_and_get_params(self, query: str) -> Tuple[Optional[str], dict]:
        """Classify the query and determine adaptive parameters, or use fixed parameters."""
        if self.fixed_params:
            # Use the fixed parameter dictionary if one was provided.
            params = self.fixed_params
            classification_result = "Fixed"
            if self.verbose:
                print(f"Using fixed parameters: {params}")
        else:
            params = {
                "top_k": 5,  # Default top-k
                "max_keywords_per_query": 4,  # Default max keywords
                "max_knowledge_sequence": 2,  # Default max knowledge sequence
            }
            classification_result = None
            if self.classifier:
                classification = self.classifier(query)[0]
                label = classification["label"]  # Classification label, e.g. "label_1"
                classification_result = label
                if self.verbose:
                    print(f"Query Classification: {label} with score {classification['score']}")
                # Look up the label in the mappings, falling back to the defaults above.
                if label in self.param_mappings:
                    params = self.param_mappings[label]
                elif self.verbose:
                    print(f"Warning: No mapping found for label {label}, using default parameters.")

        self._classification_result = classification_result
        return classification_result, params
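    # Worked example (hypothetical labels): with the default `param_mappings` above, a
    # query classified as "label_1" retrieves with top_k=7, extracts at most 4 keywords
    # per KG query, and follows KG paths of length <= 2; a label with no mapping falls
    # back to the defaults defined in `classify_query_and_get_params`.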
    def classify_query(self, query_str: str) -> Optional[str]:
        """Classify the query into one of the predefined categories using the LLM,
        or skip classification if no categories were provided."""
        if not self.categories_list:
            if self.verbose:
                print("No categories provided, skipping query classification.")
            return None

        # Build the classification prompt with the external prompt helper.
        classification_prompt = get_classification_prompt(self.categories_list) + f" Query: '{query_str}'"
        response = self._llm.complete(classification_prompt)
        category = response.text.strip()

        # Return the category only if it is one of the known categories.
        return category if category in self.categories_list else None

    def generate_queries(self, query_str: str, category: Optional[str], num_queries: int = 3) -> List[str]:
        """Generate query variations with the LLM, appending the category as an extra query when available."""
        query_gen_prompt = get_query_generation_prompt(query_str, num_queries)
        response = self._llm.complete(query_gen_prompt)
        queries = [query.strip() for query in response.text.split("\n") if query.strip()]
        if category:
            queries.append(category)
        return queries

    async def run_queries(self, queries: List[str], retrievers: List[BaseRetriever]) -> dict:
        """Run every query against every retriever concurrently."""
        tasks = []
        task_keys = []
        for query in queries:
            for retriever in retrievers:
                tasks.append(retriever.aretrieve(query))
                task_keys.append(query)
        task_results = await asyncio.gather(*tasks)

        # Key results by (query, task index). Zipping `queries` directly against
        # `task_results` would mislabel and silently drop every result beyond the
        # first len(queries) whenever more than one retriever is active.
        return {(query, i): result for i, (query, result) in enumerate(zip(task_keys, task_results))}

    def fuse_vector_and_bm25_results(self, results_dict, similarity_top_k: int) -> List[NodeWithScore]:
        """Fuse results from the vector and BM25 retrievers with reciprocal rank fusion (RRF)."""
        k = 60.0  # `k` controls the impact of outlier rankings.
        fused_scores = {}
        text_to_node = {}

        for nodes_with_scores in results_dict.values():
            for rank, node_with_score in enumerate(
                sorted(nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True)
            ):
                text = node_with_score.node.get_content()
                text_to_node[text] = node_with_score
                if text not in fused_scores:
                    fused_scores[text] = 0.0
                fused_scores[text] += 1.0 / (rank + k)

        reranked_results = dict(sorted(fused_scores.items(), key=lambda x: x[1], reverse=True))
        reranked_nodes: List[NodeWithScore] = []
        for text, score in reranked_results.items():
            if text in text_to_node:
                node = text_to_node[text]
                node.score = score
                reranked_nodes.append(node)
            elif self.verbose:
                print(f"Warning: Text not found in `text_to_node`: {text}")

        return reranked_nodes[:similarity_top_k]

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes for the given query."""
        category = None
        if self._rewriter:
            category = self.classify_query(query_bundle.query_str)
            if self.verbose and category:
                print(f"Classified Category: {category}")

        classification_result, params = self.classify_query_and_get_params(query_bundle.query_str)
        self._classification_result = classification_result
        top_k = params["top_k"]

        if self._reranker_model_name:
            self._reranker = SentenceTransformerRerank(model=self._reranker_model_name, top_n=top_k)
            if self.verbose:
                print(f"Initialized reranker with top_n: {top_k}")

        # Scale the number of query rewrites with the adaptive top_k.
        num_queries = 3 if top_k == 5 else 5 if top_k == 7 else 7
        if self.verbose:
            print(f"Number of Query Rewrites: {num_queries}")

        if self._rewriter:
            queries = self.generate_queries(query_bundle.query_str, category, num_queries=num_queries)
            if self.verbose:
                print(f"Generated Queries: {queries}")
        else:
            queries = [query_bundle.query_str]

        active_retrievers = []
        if self._vector_retriever:
            active_retrievers.append(self._vector_retriever)
        if self._bm25_retriever:
            active_retrievers.append(self._bm25_retriever)
        if not active_retrievers:
            raise ValueError("No active retriever provided!")

        results = asyncio.run(self.run_queries(queries, active_retrievers))
        if self.verbose:
            print(f"Fusion Results: {results}")

        final_results = self.fuse_vector_and_bm25_results(results, similarity_top_k=top_k)

        if self._reranker:
            final_results = self._reranker.postprocess_nodes(final_results, query_bundle)
            if self.verbose:
                print(f"Reranked Results: {final_results}")
        else:
            final_results = final_results[:top_k]

        if self._rewriter:
            # Deduplicate nodes retrieved by more than one rewritten query.
            unique_nodes = {}
            for node in final_results:
                content = node.node.get_content()
                if content not in unique_nodes:
                    unique_nodes[content] = node
            final_results = list(unique_nodes.values())

        if self.verbose:
            print(f"Final Results: {final_results}")
        return final_results

    def get_classification_result(self) -> Optional[str]:
        return getattr(self, "_classification_result", None)
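
# A minimal, self-contained sketch of the reciprocal rank fusion (RRF) arithmetic used in
# `fuse_vector_and_bm25_results` above. The ranked lists and document ids here are
# hypothetical placeholders, not outputs of the retrievers in this module.
def _rrf_demo(k: float = 60.0):
    ranked_lists = [
        ["doc_a", "doc_b", "doc_c"],  # e.g. vector results, best first
        ["doc_b", "doc_a", "doc_d"],  # e.g. BM25 results, best first
    ]
    scores = {}
    for results in ranked_lists:
        for rank, doc in enumerate(results):
            # Each list contributes 1 / (rank + k) to the document's fused score.
            scores[doc] = scores.get(doc, 0.0) + 1.0 / (rank + k)
    # "doc_a" and "doc_b" tie (1/60 + 1/61 each) and outrank the singletons "doc_c"/"doc_d".
    return sorted(scores.items(), key=lambda x: x[1], reverse=True)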

class HyPARetriever(PARetriever):
    """Custom retriever that extends PARetriever with knowledge graph (KG) search."""

    def __init__(
        self,
        llm,  # LLM for query generation
        vector_retriever: Optional[VectorIndexRetriever] = None,
        bm25_retriever: Optional[BaseRetriever] = None,
        kg_index=None,  # Knowledge graph index
        property_index: bool = True,  # Whether `kg_index` is a property graph
        pg_filters=None,  # Optional filters for the property graph sub-retrievers
        **kwargs,  # Any additional PARetriever arguments
    ):
        # Initialize PARetriever to reuse all of its functionality.
        super().__init__(
            llm=llm,
            vector_retriever=vector_retriever,
            bm25_retriever=bm25_retriever,
            **kwargs,
        )

        # Knowledge-graph-specific components.
        self._pg_filters = pg_filters
        self._kg_index = kg_index
        self.property_index = property_index

    def _initialize_kg_retriever(self, params):
        """Initialize the KG retriever appropriate for the configured index type."""
        graph_index = self._kg_index
        filters = self._pg_filters

        if self._kg_index and not self.property_index:
            # Triplet-based knowledge graph: use KGTableRetriever.
            return KGTableRetriever(
                index=self._kg_index,
                retriever_mode="hybrid",
                max_keywords_per_query=params["max_keywords_per_query"],
                max_knowledge_sequence=params["max_knowledge_sequence"],
            )
        elif self._kg_index and self.property_index:
            # Property graph (used for the demo): combine a vector-context
            # sub-retriever with an LLM synonym sub-retriever.
            vector_retriever = VectorContextRetriever(
                graph_store=graph_index.property_graph_store,
                similarity_top_k=params["max_keywords_per_query"],
                path_depth=params["max_knowledge_sequence"],
                include_text=True,
                filters=filters,
            )
            synonym_retriever = LLMSynonymRetriever(
                graph_store=graph_index.property_graph_store,
                llm=self._llm,
                include_text=True,
                filters=filters,
            )
            return graph_index.as_retriever(sub_retrievers=[vector_retriever, synonym_retriever])
            # Alternatively: graph_index.as_retriever(similarity_top_k=params["top_k"])
        return None

    def _combine_with_kg_results(self, vector_bm25_results, kg_results):
        """Combine KG results with the vector and BM25 results according to `mode`."""
        vector_ids = {n.node.id_ for n in vector_bm25_results}
        kg_ids = {n.node.id_ for n in kg_results}

        combined_results = {n.node.id_: n for n in vector_bm25_results}
        combined_results.update({n.node.id_: n for n in kg_results})

        if self._mode == "AND":
            result_ids = vector_ids.intersection(kg_ids)
        else:
            result_ids = vector_ids.union(kg_ids)

        return [combined_results[rid] for rid in result_ids]

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve nodes, integrating knowledge graph results when available."""
        # Reuse PARetriever's pipeline for the vector and BM25 results.
        final_results = super()._retrieve(query_bundle)

        # If a KG index is available, retrieve from it as well.
        if self._kg_index:
            params = self.classify_query_and_get_params(query_bundle.query_str)[1]
            kg_retriever = self._initialize_kg_retriever(params)
            if kg_retriever:
                kg_nodes = kg_retriever.retrieve(query_bundle)
                # Only combine KG and vector/BM25 results when using a property graph.
                if self.property_index:
                    final_results = self._combine_with_kg_results(final_results, kg_nodes)

        return final_results
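
# A minimal sketch of the `mode` semantics in `_combine_with_kg_results` above, using
# hypothetical node ids in place of NodeWithScore objects: "AND" keeps only nodes found
# by both the vector/BM25 path and the KG path, while "OR" keeps their union.
def _mode_demo(mode: str = "OR"):
    vector_ids = {"n1", "n2", "n3"}  # ids from the fused vector/BM25 results
    kg_ids = {"n2", "n4"}            # ids from the knowledge graph retriever
    if mode == "AND":
        return vector_ids & kg_ids   # {"n2"}
    return vector_ids | kg_ids       # {"n1", "n2", "n3", "n4"}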

# ---------------------------------------------------------------------------
# Demo script (separate module): builds the indices, retrievers, and chat
# engines, using the retrievers defined above.
# ---------------------------------------------------------------------------

from dotenv import load_dotenv
from llama_index.core import VectorStoreIndex, Settings, KnowledgeGraphIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.readers.file import PyMuPDFReader
from llama_index.core.chat_engine import ContextChatEngine
from llama_index.core.memory.chat_memory_buffer import ChatMemoryBuffer

from utils_code import initialize_openai_creds, create_llm
from retrievers import PARetriever, HyPARetriever


def load_documents():
    """Load and return documents from the specified file paths."""
    loader = PyMuPDFReader()
    documents1 = loader.load(file_path="../../legal_data/LL144/LL144.pdf")
    documents2 = loader.load(file_path="../../legal_data/LL144/LL144_Definitions.pdf")
    return documents1 + documents2


def create_indices(documents, llm, embed_model):
    """Create and return the VectorStoreIndex and (optionally) the KnowledgeGraphIndex."""
    splitter = SentenceSplitter(chunk_size=512)
    vector_index = VectorStoreIndex.from_documents(
        documents,
        embed_model=embed_model,
        transformations=[splitter],
    )

    # The knowledge graph build is disabled for the demo; uncomment to enable it.
    graph_index = None
    # graph_index = KnowledgeGraphIndex.from_documents(
    #     documents,
    #     max_triplets_per_chunk=5,
    #     llm=llm,
    #     embed_model=embed_model,
    #     include_embeddings=True,
    #     transformations=[splitter],
    # )

    # Return both so callers can unpack `vector_index, graph_index` even when
    # the KG build is disabled.
    return vector_index, graph_index

def create_retrievers(vector_index, graph_index, llm, category_list):
    """Create and return the PA and HyPA retrievers."""
    vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=10)
    bm25_retriever = BM25Retriever.from_defaults(index=vector_index, similarity_top_k=10)

    PA_retriever = PARetriever(
        llm=llm,
        categories_list=category_list,
        rewriter=True,
        vector_retriever=vector_retriever,
        bm25_retriever=bm25_retriever,
        classifier_model="rk68/distilbert-q-classifier-3",
        verbose=False,
    )

    HyPA_retriever = HyPARetriever(
        llm=llm,
        categories_list=category_list,
        rewriter=True,
        kg_index=graph_index,  # None unless the KG build in `create_indices` is enabled
        vector_retriever=vector_retriever,
        bm25_retriever=bm25_retriever,
        classifier_model="rk68/distilbert-q-classifier-3",
        verbose=False,
        property_index=False,
    )

    return PA_retriever, HyPA_retriever


def create_chat_engine(retriever, memory):
    """Create and return a ContextChatEngine using the provided retriever and memory."""
    return ContextChatEngine.from_defaults(
        retriever=retriever,
        memory=memory,
        verbose=False,
    )


def main():
    # Initialize the environment and LLM.
    load_dotenv()
    gpt35_creds, gpt4o_mini_creds, gpt4o_creds = initialize_openai_creds()
    llm_gpt35 = create_llm(gpt35_creds=gpt35_creds, gpt4o_mini_creds=gpt4o_mini_creds, gpt4o_creds=gpt4o_creds)

    # Set global settings for the embedding model and LLM.
    embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
    Settings.embed_model = embed_model
    Settings.llm = llm_gpt35

    category_list = [
        '§ 5-301 Bias Audit',
        '§ 5-302 Data Requirements',
        '§ 5-303 Published Results',
        '§ 5-304 Notice to Candidates and Employees',
    ]

    # Load documents and create indices.
    documents = load_documents()
    vector_index, graph_index = create_indices(documents, llm_gpt35, embed_model)

    # Create retrievers.
    PA_retriever, HyPA_retriever = create_retrievers(vector_index, graph_index, llm_gpt35, category_list)

    # Initialize chat memory (shared by both engines, so the second chat also
    # sees the first exchange).
    memory = ChatMemoryBuffer.from_defaults(token_limit=8192)

    # Create chat engines.
    PA_chat_engine = create_chat_engine(PA_retriever, memory)
    HyPA_chat_engine = create_chat_engine(HyPA_retriever, memory)

    # Sample question and responses.
    question = "What is a bias audit?"
    PA_response = PA_chat_engine.chat(question)
    HyPA_response = HyPA_chat_engine.chat(question)

    # Output the responses in a nicely formatted manner.
    print("\n" + "=" * 50)
    print(f"Question: {question}")
    print("=" * 50)
    print("\n------- PA Retriever Response -------")
    print(PA_response)
    print("\n------- HyPA Retriever Response -------")
    print(HyPA_response)
    print("=" * 50 + "\n")


if __name__ == '__main__':
    main()
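
# A hedged example of querying a retriever directly, outside the chat engines built in
# `main()`. `BaseRetriever.retrieve` accepts a raw query string; the question below is
# illustrative only.
#
#     PA_retriever, HyPA_retriever = create_retrievers(vector_index, graph_index, llm_gpt35, category_list)
#     nodes = PA_retriever.retrieve("When must the results of a bias audit be published?")
#     for node in nodes:
#         print(node.score, node.node.get_content()[:80])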