#####################################################
### DOCUMENT PROCESSOR [RETRIEVER]
#####################################################
# Jonathan Wang

# ABOUT:
# This project creates an app to chat with PDFs.
# This is the RETRIEVER,
# which defines the main way that document
# snippets are identified.
#####################################################
## TODO:

#####################################################
## IMPORTS:
import logging
from typing import Optional, List, Tuple, Dict, cast
from collections import defaultdict

import streamlit as st
import numpy as np

from llama_index.core.utils import truncate_text
from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core import VectorStoreIndex  # , StorageContext
from llama_index.core.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle
from llama_index.core.callbacks.base import CallbackManager

# Own Modules:
from merger import _merge_on_scores

# Lazy Loading:

logger = logging.getLogger(__name__)

#####################################################
## CODE:
class RAGRetriever(BaseRetriever):
    """
    Jonathan Wang's custom-built retriever over our vector store.

    Combination of Hybrid Retrieval (BM25 x Vector Embeddings) + AutoMergingRetriever
    https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/retrievers/auto_merging_retriever.py
    """

    def __init__(
        self,
        vector_store_index: VectorStoreIndex,
        semantic_top_k: int = 10,
        sparse_top_k: int = 6,
        fusion_similarity_top_k: int = 10,  # total number of snippets to retrieve after the Reciprocal Rerank.
        semantic_weight_fraction: float = 0.6,  # fraction of weight given to semantic (cosine) scores vs sparse (BM25) scores.
        merge_up_thresh: float = 0.5,  # fraction of a parent's children that must be retrieved to merge up to the parent level.
        verbose: bool = True,
        callback_manager: Optional[CallbackManager] = None,
        object_map: Optional[dict] = None,
        objects: Optional[List[IndexNode]] = None,
    ) -> None:
        """Init params."""
        self._vector_store_index = vector_store_index
        self.sentence_vector_retriever = VectorIndexRetriever(
            index=vector_store_index,
            similarity_top_k=semantic_top_k,
        )
        self.sentence_bm25_retriever = BM25Retriever.from_defaults(
            # nodes=list(vector_store_index.storage_context.docstore.docs.values()),
            index=vector_store_index,  # TODO: Confirm this works.
            similarity_top_k=sparse_top_k,
        )
        self._fusion_similarity_top_k = fusion_similarity_top_k
        self._semantic_weight_fraction = semantic_weight_fraction
        self._merge_up_thresh = merge_up_thresh
        super().__init__(
            # callback_manager=callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    @classmethod
    def class_name(cls) -> str:
        """Class name."""
        return "RAGRetriever"

    def _get_parents_and_merge(
        self, nodes: List[NodeWithScore]
    ) -> Tuple[List[NodeWithScore], bool]:
        """Get parents and merge nodes."""
        # Retrieve all parent nodes.
        # NOTE: this assumes nodes carry parent/child relationships in the
        # docstore (e.g. from a hierarchical node parser).
        parent_nodes: Dict[str, BaseNode] = {}
        parent_cur_children_dict: Dict[str, List[NodeWithScore]] = defaultdict(list)
        for node in nodes:
            if node.node.parent_node is None:
                continue
            parent_node_info = node.node.parent_node

            # Fetch the actual parent node if it isn't in the `parent_nodes` cache yet.
            parent_node_id = parent_node_info.node_id
            if parent_node_id not in parent_nodes:
                parent_node = self._vector_store_index.storage_context.docstore.get_document(
                    parent_node_id
                )
                parent_nodes[parent_node_id] = cast(BaseNode, parent_node)

            # Add reference to child from parent.
            parent_cur_children_dict[parent_node_id].append(node)

        # Compute ratios and "merge" nodes:
        # merging deletes some child nodes and adds their parent nodes instead.
        node_ids_to_delete = set()
        nodes_to_add: Dict[str, NodeWithScore] = {}
        for parent_node_id, parent_node in parent_nodes.items():
            parent_child_nodes = parent_node.child_nodes
            parent_num_children = len(parent_child_nodes) if parent_child_nodes else 1
            parent_cur_children = parent_cur_children_dict[parent_node_id]
            ratio = len(parent_cur_children) / parent_num_children

            # If the ratio is high enough, merge up to the next level in the hierarchy.
            if ratio > self._merge_up_thresh:
                node_ids_to_delete.update(
                    {n.node.node_id for n in parent_cur_children}
                )

                parent_node_text = truncate_text(getattr(parent_node, "text", ""), 100)
                info_str = (
                    f"> Merging {len(parent_cur_children)} nodes into parent node.\n"
                    f"> Parent node id: {parent_node_id}.\n"
                    f"> Parent node text: {parent_node_text}\n"
                )
                # logger.info(info_str)
                if self._verbose:
                    print(info_str)

                # Add the parent node, scored as the average of its retrieved children.
                avg_score = sum(
                    n.get_score() or 0.0 for n in parent_cur_children
                ) / len(parent_cur_children)
                parent_node_with_score = NodeWithScore(
                    node=parent_node, score=avg_score
                )
                nodes_to_add[parent_node_id] = parent_node_with_score

        # Delete old child nodes, add new parent nodes.
        new_nodes = [n for n in nodes if n.node.node_id not in node_ids_to_delete]
        new_nodes.extend(list(nodes_to_add.values()))

        is_changed = len(node_ids_to_delete) > 0
        return new_nodes, is_changed

    def _fill_in_nodes(
        self, nodes: List[NodeWithScore]
    ) -> Tuple[List[NodeWithScore], bool]:
        """Fill in nodes."""
        new_nodes: List[NodeWithScore] = []
        is_changed = False
        for idx, node in enumerate(nodes):
            new_nodes.append(node)
            if idx >= len(nodes) - 1:
                continue

            cur_node = cast(BaseNode, node.node)
            # If there's a node in the middle, add that to the queue.
            if (
                cur_node.next_node is not None
                and cur_node.next_node == nodes[idx + 1].node.prev_node
            ):
                is_changed = True
                next_node = self._vector_store_index.storage_context.docstore.get_document(
                    cur_node.next_node.node_id
                )
                next_node = cast(BaseNode, next_node)

                next_node_text = truncate_text(getattr(next_node, "text", ""), 100)  # TODO: why not higher?
                info_str = (
                    f"> Filling in node. Node id: {cur_node.next_node.node_id}\n"
                    f"> Node text: {next_node_text}\n"
                )
                # logger.info(info_str)
                if self._verbose:
                    print(info_str)

                # Set the score to the average of the current node and the next node.
                avg_score = (node.get_score() + nodes[idx + 1].get_score()) / 2
                new_nodes.append(NodeWithScore(node=next_node, score=avg_score))

        return new_nodes, is_changed

    def _try_merging(
        self, nodes: List[NodeWithScore]
    ) -> Tuple[List[NodeWithScore], bool]:
        """Try different ways to merge nodes."""
        # First try filling in nodes.
        nodes, is_changed_0 = self._fill_in_nodes(nodes)
        # Then try merging nodes up to their parents.
        nodes, is_changed_1 = self._get_parents_and_merge(nodes)
        return nodes, is_changed_0 or is_changed_1

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve."""
        # Get the nodes found by each retriever.
        vector_sentence_nodes = self.sentence_vector_retriever.retrieve(query_bundle)
        bm25_sentence_nodes = self.sentence_bm25_retriever.retrieve(query_bundle)

        # Get initial nodes from the hybrid search.
        initial_nodes = _merge_on_scores(
            vector_sentence_nodes,
            bm25_sentence_nodes,
            [getattr(a, "score", np.nan) for a in vector_sentence_nodes],
            [getattr(b, "score", np.nan) for b in bm25_sentence_nodes],
            a_weight=self._semantic_weight_fraction,
            top_k=self._fusion_similarity_top_k,
        )

        # Merge nodes repeatedly until the hierarchy stabilizes.
        cur_nodes, is_changed = self._try_merging(list(initial_nodes))  # technically _merge_on_scores returns a sequence.
        while is_changed:
            cur_nodes, is_changed = self._try_merging(cur_nodes)

        # Sort by similarity.
        cur_nodes.sort(key=lambda x: x.get_score(), reverse=True)

        # Some other reranking and filtering node postprocessors could go here:
        # https://docs.llamaindex.ai/en/stable/module_guides/querying/node_postprocessors/root.html
        return cur_nodes
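
#####################################################
# The `_merge_on_scores` helper lives in the local `merger` module, which is
# not shown in this file. The sketch below is an illustrative ASSUMPTION of
# its contract, inferred only from the call site in `_retrieve`: normalize the
# two score lists, take a per-node weighted sum across both result lists, and
# keep the top_k fused nodes. The real helper may differ (e.g. it may use a
# true Reciprocal Rank Fusion, as the `fusion_similarity_top_k` comment hints).
def _example_merge_on_scores(
    a_items: List[NodeWithScore],
    b_items: List[NodeWithScore],
    a_scores: List[float],
    b_scores: List[float],
    a_weight: float = 0.6,
    top_k: int = 10,
) -> List[NodeWithScore]:
    """Hypothetical stand-in for merger._merge_on_scores (illustration only)."""
    def _normalize(scores: List[float]) -> List[float]:
        # Min-max normalize, treating missing (None/NaN) scores as 0.
        valid = [s for s in scores if s is not None and not np.isnan(s)]
        if not valid or max(valid) == min(valid):
            return [0.5 if (s is not None and not np.isnan(s)) else 0.0 for s in scores]
        lo, hi = min(valid), max(valid)
        return [
            (s - lo) / (hi - lo) if (s is not None and not np.isnan(s)) else 0.0
            for s in scores
        ]

    # Accumulate a weighted fused score per node id across both result lists.
    fused: Dict[str, Tuple[NodeWithScore, float]] = {}
    for items, norm_scores, weight in (
        (a_items, _normalize(a_scores), a_weight),
        (b_items, _normalize(b_scores), 1.0 - a_weight),
    ):
        for item, score in zip(items, norm_scores):
            prev = fused.get(item.node.node_id, (item, 0.0))[1]
            fused[item.node.node_id] = (item, prev + weight * score)

    ranked = sorted(fused.values(), key=lambda pair: pair[1], reverse=True)
    return [NodeWithScore(node=item.node, score=score) for item, score in ranked[:top_k]]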

@st.cache_resource
def get_retriever(
    _vector_store_index: VectorStoreIndex,  # leading underscore: st.cache_resource skips hashing this argument.
    semantic_top_k: int = 10,
    sparse_top_k: int = 6,
    fusion_similarity_top_k: int = 10,  # total number of snippets to retrieve after the Reciprocal Rerank.
    semantic_weight_fraction: float = 0.6,  # fraction of weight given to semantic (cosine) scores vs sparse (BM25) scores.
    merge_up_thresh: float = 0.5,  # fraction of a parent's children that must be retrieved to merge up to the parent level.
    verbose: bool = True,
    _callback_manager: Optional[CallbackManager] = None,
    object_map: Optional[dict] = None,
    objects: Optional[List[IndexNode]] = None,
) -> BaseRetriever:
    """Get the retriever to use.

    Args:
        _vector_store_index (VectorStoreIndex): The vector store to query on.
        semantic_top_k (int, optional): Top k nodes to retrieve semantically (cosine). Defaults to 10.
        sparse_top_k (int, optional): Top k nodes to retrieve sparsely (BM25). Defaults to 6.
        fusion_similarity_top_k (int, optional): Maximum number of nodes to retrieve after fusing. Defaults to 10.
        semantic_weight_fraction (float, optional): Fraction of weight given to semantic vs BM25 scores. Defaults to 0.6.
        merge_up_thresh (float, optional): Fraction of a parent's children needed to merge up. Defaults to 0.5.
        verbose (bool, optional): Whether to print retrieval progress. Defaults to True.
        _callback_manager (Optional[CallbackManager], optional): Callback manager. Defaults to None.
        object_map (Optional[dict], optional): Object map. Defaults to None.
        objects (Optional[List[IndexNode]], optional): Objects list. Defaults to None.

    Returns:
        BaseRetriever: Retriever to use.
    """
    retriever = RAGRetriever(
        vector_store_index=_vector_store_index,
        semantic_top_k=semantic_top_k,
        sparse_top_k=sparse_top_k,
        fusion_similarity_top_k=fusion_similarity_top_k,
        semantic_weight_fraction=semantic_weight_fraction,
        merge_up_thresh=merge_up_thresh,
        verbose=verbose,
        callback_manager=_callback_manager,
        object_map=object_map,
        objects=objects,
    )
    return retriever
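
#####################################################
# USAGE (illustrative sketch):
# A minimal, hypothetical example of building an index and retrieving with it.
# It assumes an embedding model is already configured on llama_index Settings
# (the default requires an OpenAI API key) and uses placeholder documents; the
# real app builds its index from parsed PDF snippets.
if __name__ == "__main__":
    from llama_index.core import Document

    docs = [
        Document(text="The quarterly report shows revenue grew 12 percent."),
        Document(text="Operating expenses were flat year over year."),
    ]
    index = VectorStoreIndex.from_documents(docs)
    retriever = get_retriever(index, verbose=True)  # cached by Streamlit when run in-app.
    for node_with_score in retriever.retrieve("How did revenue change?"):
        print(node_with_score.get_score(), node_with_score.node.node_id)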