Spaces:
Paused
Paused
#####################################################
### DOCUMENT PROCESSOR [FULLDOC]
#####################################################
### Jonathan Wang
# ABOUT:
# This creates an app to chat with PDFs.
# This is the FULLDOC module,
# which defines a class that associates documents
# with their critical information
# and their tools (keywords, summary, query engine, etc.).
#####################################################
### TODO Board:
# Automatically determine which reader to use for each document based on the file type.
#####################################################
### PROGRAM SETTINGS
#####################################################
### PROGRAM IMPORTS
from __future__ import annotations | |
import asyncio | |
from pathlib import Path | |
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar | |
from uuid import UUID, uuid4 | |
from llama_index.core import StorageContext, VectorStoreIndex | |
from llama_index.core.query_engine import SubQuestionQueryEngine | |
from llama_index.core.schema import BaseNode, TransformComponent | |
from llama_index.core.settings import Settings | |
from llama_index.core.tools import QueryEngineTool, ToolMetadata | |
from streamlit import session_state as ss | |
if TYPE_CHECKING: | |
from llama_index.core.base.base_query_engine import BaseQueryEngine | |
from llama_index.core.callbacks import CallbackManager | |
from llama_index.core.node_parser import NodeParser | |
from llama_index.core.readers.base import BaseReader | |
from llama_index.core.response_synthesizers import BaseSynthesizer | |
from llama_index.core.retrievers import BaseRetriever | |
# Own Modules | |
from engine import get_engine | |
from keywords import KeywordMetadataAdder | |
from retriever import get_retriever | |
from storage import get_docstore, get_vector_store | |
from summary import DEFAULT_ONELINE_SUMMARY_TEMPLATE, DEFAULT_TREE_SUMMARY_TEMPLATE | |
##################################################### | |
### SCRIPT | |
# Type variable for node-processing callables. It is bound to BaseNode, so a function typed
# `list[GenericNode] -> list[GenericNode]` is understood to hand back the node type it was given.
GenericNode = TypeVar("GenericNode", bound=BaseNode)
class FullDocument:
    """Bundle a document together with its derived information and RAG tooling.

    A new instance only carries identifiers (see ``__init__``). Everything else is
    attached by calling the pipeline methods below. Each method checks that the
    attribute it depends on has already been set and raises ``ValueError`` if not:

        file_to_nodes                  -> nodes
        nodes_to_summary               -> summary
        summary_to_oneline             -> summary_oneline
        nodes_to_document_keywords     -> keywords
        nodes_to_storage               -> docstore, storage_context, vector_store_index
        storage_to_retriever           -> retriever
        retriever_to_engine            -> engine
        engine_to_sub_question_engine  -> subquestion_engine

    Args:
        name (str): The name of the document.
        file_path (Path | str): The path to the document.
        metadata (dict[str, Any] | None): Optional extra metadata. Defaults to None.
    """

    # Identifiers
    id: UUID
    name: str
    file_path: Path
    file_name: str

    # Basic Contents
    summary: str
    summary_oneline: str  # A one line summary of the document.
    keywords: set[str]  # Set of keywords in the document.
    # entities: Set[str]  # list of entities in document  ## TODO: Add entities
    metadata: dict[str, Any] | None
    # NOTE: other metadata that might be useful:
    #   Document Creation / Last Date (e.g., recency important for legal/medical questions)
    #   Document Source and Trustworthiness
    #   Document Access Level (though this isn't important for us here.)
    #   Document Citations?
    #   Document Format? (text/spreadsheet/presentation/image/etc.)

    # RAG Components
    # NOTE: `nodes_to_storage` additionally sets `self.docstore` (the object returned by `get_docstore`).
    nodes: list[BaseNode]
    storage_context: StorageContext  # NOTE: current setup has single storage context per document.
    vector_store_index: VectorStoreIndex
    retriever: BaseRetriever  # TODO(Jonathan Wang): Consider multiple retrievers for keywords vs semantic.
    engine: BaseQueryEngine  # TODO(Jonathan Wang): Consider multiple engines.
    subquestion_engine: SubQuestionQueryEngine

    def __init__(
        self,
        name: str,
        file_path: Path | str,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Create the document shell: a fresh UUID, the name, the path and optional metadata.

        Args:
            name (str): The name of the document.
            file_path (Path | str): The path to the document. Strings are converted to ``Path``.
            metadata (dict[str, Any] | None): Optional extra metadata. Defaults to None.
        """
        self.id = uuid4()
        self.name = name
        # Accept plain strings for convenience, but always store a Path.
        if isinstance(file_path, str):
            file_path = Path(file_path)
        self.file_path = file_path
        self.file_name = file_path.name
        self.metadata = metadata

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string.

        Declared as a classmethod so that it can be called on the class itself as
        well as on instances (the first parameter was already named ``cls``).
        """
        return "FullDocument"

    def _require(self, attribute: str, message: str) -> None:
        """Raise ``ValueError(message)`` unless ``attribute`` has been set on this document."""
        # The class-level annotations carry no values, so `hasattr` is False
        # until the corresponding pipeline step has assigned the attribute.
        if not hasattr(self, attribute):
            raise ValueError(message)

    @staticmethod
    def _get_event_loop() -> asyncio.AbstractEventLoop:
        """Return a usable event loop for the current thread, creating one if needed.

        ``asyncio.get_event_loop()`` raises ``RuntimeError`` when the current thread
        has no loop (for example in a worker thread). In that case, or if the
        loop it returns is already closed, a new loop is created and installed.
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop

    def add_name_to_nodes(self, nodes: list[GenericNode]) -> list[GenericNode]:
        """Add the name of the document to the nodes.

        The nodes are modified in place (``metadata["name"]`` is set) and the same
        list is returned so the method can be used as a post-parser step.

        Args:
            nodes (list[GenericNode]): The nodes to add the name to.

        Returns:
            list[GenericNode]: The nodes with the name added.
        """
        for node in nodes:
            node.metadata["name"] = self.name
        return nodes

    def file_to_nodes(
        self,
        reader: BaseReader,
        postreaders: list[Callable[[list[GenericNode]], list[GenericNode]] | TransformComponent] | None = None,  # NOTE: these should be used in order. and probably all TransformComponent instead.
        node_parser: NodeParser | None = None,
        postparsers: list[Callable[[list[GenericNode]], list[GenericNode]] | TransformComponent] | None = None,  # Stuff like chunking, adding Embeddings, etc.
    ) -> None:
        """Read the file at ``self.file_path`` and store the resulting nodes in ``self.nodes``.

        Steps, in order: read the file, apply the post-readers, parse into nodes,
        apply the post-parsers, and finally stamp every node with the document name.

        Args:
            reader (BaseReader): The reader used to load the file.
            postreaders: Optional callables / transforms applied, in order, to the reader output.
            node_parser (NodeParser | None): The parser that turns the loaded documents into
                nodes. Defaults to ``Settings.node_parser``.
            postparsers: Optional callables / transforms applied, in order, to the parsed nodes.
                ``add_name_to_nodes`` always runs after them. The list passed in is not modified.
        """
        # Use the provided reader to read in the file.
        print("NEWPDF: Reading input file...")
        nodes = reader.load_data(file_path=self.file_path)

        # Use node postreaders to post process the loaded documents.
        for node_postreader in postreaders or []:
            nodes = node_postreader(nodes)  # type: ignore (TransformComponent allows a list of nodes)

        # Use node parser to parse the nodes.
        if node_parser is None:
            node_parser = Settings.node_parser
        nodes = node_parser(nodes)  # type: ignore (Document is a child of BaseNode)

        # Use node postparsers to post process the nodes, then add the common name to the nodes.
        # A new list is built instead of appending to `postparsers`: appending would grow the
        # caller's list on every call and leave a bound method of this document inside it.
        for node_postparser in [*(postparsers or []), self.add_name_to_nodes]:
            nodes = node_postparser(nodes)  # type: ignore (TransformComponent allows a list of nodes)

        # Save nodes
        self.nodes = nodes  # type: ignore

    def nodes_to_summary(
        self,
        summarizer: BaseSynthesizer,  # NOTE: this is typically going to be a TreeSummarizer / SimpleSummarize for our use case
        query_str: str = DEFAULT_TREE_SUMMARY_TEMPLATE,
    ) -> None:
        """Summarize the nodes and store the result in ``self.summary``.

        Args:
            summarizer (BaseSynthesizer): The summarizer to use. Takes in the node texts and returns a summary.
            query_str (str): The query / instruction given to the summarizer.

        Raises:
            ValueError: If ``file_to_nodes`` has not been called yet.
            TypeError: If the summarizer returns something other than a string.
        """
        self._require(
            "nodes",
            "Nodes must be extracted from document using `file_to_nodes` before calling `nodes_to_summary`.",
        )

        # Only nodes that expose a `text` attribute contribute to the summary.
        text_chunks = [getattr(node, "text", "") for node in self.nodes if hasattr(node, "text")]

        # Resolve the loop BEFORE creating the coroutine, so a loop failure
        # cannot leave a never-awaited coroutine behind.
        loop = self._get_event_loop()
        summary = loop.run_until_complete(
            summarizer.aget_response(query_str=query_str, text_chunks=text_chunks)
        )

        if not isinstance(summary, str):
            # TODO(Jonathan Wang): ... this should always give us a string, right? we're not doing anything fancy with TokenGen/TokenAsyncGen/Pydantic BaseModel...
            msg = f"Summarizer must return a string summary. Actual type: {type(summary)}, with value {summary}."
            raise TypeError(msg)

        self.summary = summary

    def summary_to_oneline(
        self,
        summarizer: BaseSynthesizer,  # NOTE: this is typically going to be a SimpleSummarize / TreeSummarizer for our use case
        query_str: str = DEFAULT_ONELINE_SUMMARY_TEMPLATE,
    ) -> None:
        """Condense ``self.summary`` into a one-line summary stored in ``self.summary_oneline``.

        Args:
            summarizer (BaseSynthesizer): The summarizer to use.
            query_str (str): The query / instruction given to the summarizer.

        Raises:
            ValueError: If ``nodes_to_summary`` has not been called yet.
            TypeError: If the summarizer returns something other than a string
                (same contract as ``nodes_to_summary``).
        """
        self._require(
            "summary",
            "Summary must be extracted from document using `nodes_to_summary` before calling `summary_to_oneline`.",
        )

        oneline = summarizer.get_response(query_str=query_str, text_chunks=[self.summary])  # There's only one chunk.

        # The one-liner is later interpolated into a tool description, so a token
        # generator or Pydantic model here would silently produce garbage text.
        if not isinstance(oneline, str):
            msg = f"Summarizer must return a string summary. Actual type: {type(oneline)}, with value {oneline}."
            raise TypeError(msg)

        self.summary_oneline = oneline

    def nodes_to_document_keywords(self, keyword_extractor: KeywordMetadataAdder | None = None) -> None:
        """Extract keywords for every node and store their union in ``self.keywords``.

        Args:
            keyword_extractor (KeywordMetadataAdder | None): The keyword extractor to use.
                Defaults to a new ``KeywordMetadataAdder()``.

        Raises:
            ValueError: If ``file_to_nodes`` has not been called yet.
        """
        self._require(
            "nodes",
            "Nodes must be extracted from document using `file_to_nodes` before calling `nodes_to_document_keywords`.",
        )

        if keyword_extractor is None:
            keyword_extractor = KeywordMetadataAdder()

        # Add keywords to nodes using KeywordMetadataAdder.
        # NOTE(review): the return value is ignored, so this relies on the extractor
        # updating `self.nodes` in place - confirm against KeywordMetadataAdder.
        keyword_extractor.process_nodes(self.nodes)

        # Save keywords
        keywords: set[str] = set()
        for node in self.nodes:
            # NOTE: KeywordMetadataAdder concatenates with ", " b/c required string output
            raw_keywords = node.metadata.get("keyword_metadata", "")
            for keyword in raw_keywords.split(", "):
                keyword = keyword.strip()
                # A node without keywords yields "" here; it must not become a keyword.
                if keyword:
                    keywords.add(keyword)

        # TODO(Jonathan Wang): handle deduplicating keywords which are similar to each other (fuzzy?)
        self.keywords = keywords

    def nodes_to_storage(self, create_new_storage: bool = True) -> None:
        """Save the nodes to storage and build the vector store index.

        Sets ``self.docstore``, ``self.storage_context`` and ``self.vector_store_index``.

        Args:
            create_new_storage (bool): Must currently be True; re-using existing storage is not implemented.

        Raises:
            ValueError: If ``file_to_nodes`` has not been called yet.
            NotImplementedError: If ``create_new_storage`` is False.
        """
        self._require(
            "nodes",
            "Nodes must be extracted from document using `file_to_nodes` before calling `nodes_to_storage`.",
        )

        if not create_new_storage:
            ### TODO(Jonathan Wang): use an existing storage instead of creating a new one.
            msg = "Currently creates new storage for every document."
            raise NotImplementedError(msg)

        docstore = get_docstore(documents=self.nodes)
        self.docstore = docstore

        vector_store = get_vector_store()
        storage_context = StorageContext.from_defaults(
            docstore=docstore,
            vector_store=vector_store,
        )
        self.storage_context = storage_context

        vector_store_index = VectorStoreIndex(
            self.nodes, storage_context=storage_context
        )
        self.vector_store_index = vector_store_index

    # TODO(Jonathan Wang): Create multiple different retrievers based on the question type(?)
    # E.g., if the question is focused on specific keywords or phrases, use a retriever oriented towards sparse scores.
    def storage_to_retriever(
        self,
        semantic_nodes: int = 6,
        sparse_nodes: int = 3,
        fusion_nodes: int = 3,
        semantic_weight: float = 0.6,
        merge_up_thresh: float = 0.5,
        callback_manager: CallbackManager | None = None,
    ) -> None:
        """Create the retriever from storage and store it in ``self.retriever``.

        The numeric arguments are forwarded unchanged to ``get_retriever``.

        Args:
            semantic_nodes (int): Passed as ``semantic_top_k``.
            sparse_nodes (int): Passed as ``sparse_top_k``.
            fusion_nodes (int): Passed as ``fusion_similarity_top_k``.
            semantic_weight (float): Passed as ``semantic_weight_fraction``.
            merge_up_thresh (float): Passed as ``merge_up_thresh``.
            callback_manager (CallbackManager | None): Callback manager to use.
                Defaults to ``callback_manager`` from the Streamlit session state.

        Raises:
            ValueError: If ``nodes_to_storage`` has not been called yet.
        """
        self._require(
            "vector_store_index",
            "Vector store must be extracted from document using `nodes_to_storage` before calling `storage_to_retriever`.",
        )

        # Fall back to the session-wide manager only when none was passed at all.
        if callback_manager is None:
            callback_manager = ss.callback_manager

        retriever = get_retriever(
            _vector_store_index=self.vector_store_index,
            semantic_top_k=semantic_nodes,
            sparse_top_k=sparse_nodes,
            fusion_similarity_top_k=fusion_nodes,
            semantic_weight_fraction=semantic_weight,
            merge_up_thresh=merge_up_thresh,
            verbose=True,
            _callback_manager=callback_manager,
        )
        self.retriever = retriever

    def retriever_to_engine(
        self,
        response_synthesizer: BaseSynthesizer,
        callback_manager: CallbackManager | None = None,
    ) -> None:
        """Create the query engine from the retriever and store it in ``self.engine``.

        Args:
            response_synthesizer (BaseSynthesizer): The synthesizer handed to ``get_engine``.
            callback_manager (CallbackManager | None): Callback manager to use.
                Defaults to ``callback_manager`` from the Streamlit session state.

        Raises:
            ValueError: If ``storage_to_retriever`` has not been called yet.
        """
        self._require(
            "retriever",
            "Retriever must be extracted from document using `storage_to_retriever` before calling `retriever_to_engine`.",
        )

        # Fall back to the session-wide manager only when none was passed at all.
        if callback_manager is None:
            callback_manager = ss.callback_manager

        engine = get_engine(
            retriever=self.retriever,
            response_synthesizer=response_synthesizer,
            callback_manager=callback_manager,
        )
        self.engine = engine

    # TODO(Jonathan Wang): Create Summarization Index and Engine.
    def engine_to_sub_question_engine(self) -> None:
        """Wrap ``self.engine`` in a sub-question query engine for complex, multi-step questions.

        The basic engine is exposed as a single tool, described by the document's
        one-line summary. The result is stored in ``self.subquestion_engine``.

        Raises:
            ValueError: If the one-line summary or the basic query engine has not been created yet.
        """
        self._require(
            "summary_oneline",
            "One Line Summary must be created for the document before calling `engine_to_sub_question_engine`",
        )
        self._require(
            "engine",
            "Basic Query Engine must be created before calling `engine_to_sub_question_engine`",
        )

        sqe_tools = [
            QueryEngineTool(
                query_engine=self.engine,  # TODO(Jonathan Wang): handle multiple engines?
                metadata=ToolMetadata(
                    # NOTE(review): no separator between the document name and the suffix;
                    # kept as-is because other code may match on this exact tool name - confirm.
                    name=(self.name + "simple query answerer"),
                    description=f"""A tool that answers simple questions about the following document: {self.summary_oneline}"""
                )
            )
            # TODO(Jonathan Wang): add more tools
        ]

        subquestion_engine = SubQuestionQueryEngine.from_defaults(
            query_engine_tools=sqe_tools,
            verbose=True,
            use_async=True,
        )
        self.subquestion_engine = subquestion_engine