Autodoc-Lifter / full_doc.py
Jonathan Wang
initial commit
89cbc4d
#####################################################
### DOCUMENT PROCESSOR [FULLDOC]
#####################################################
### Jonathan Wang
# ABOUT:
# This project creates an app to chat with PDFs.
# This module defines FULLDOC (the `FullDocument` class),
# which associates each document
# with its critical information (keywords, summary, etc.)
# and its tools (retriever, query engine, etc.).
#####################################################
### TODO Board:
# Automatically determine which reader to use for each document based on the file type.
#####################################################
### PROGRAM SETTINGS
#####################################################
### PROGRAM IMPORTS
from __future__ import annotations
import asyncio
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar
from uuid import UUID, uuid4
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.query_engine import SubQuestionQueryEngine
from llama_index.core.schema import BaseNode, TransformComponent
from llama_index.core.settings import Settings
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from streamlit import session_state as ss
if TYPE_CHECKING:
from llama_index.core.base.base_query_engine import BaseQueryEngine
from llama_index.core.callbacks import CallbackManager
from llama_index.core.node_parser import NodeParser
from llama_index.core.readers.base import BaseReader
from llama_index.core.response_synthesizers import BaseSynthesizer
from llama_index.core.retrievers import BaseRetriever
# Own Modules
from engine import get_engine
from keywords import KeywordMetadataAdder
from retriever import get_retriever
from storage import get_docstore, get_vector_store
from summary import DEFAULT_ONELINE_SUMMARY_TEMPLATE, DEFAULT_TREE_SUMMARY_TEMPLATE
#####################################################
### SCRIPT
# Type variable bound to BaseNode: lets node-processing callables declare that
# they return the same node subclass they were given.
GenericNode = TypeVar("GenericNode", bound=BaseNode)
class FullDocument:
    """Bundle a document together with its derived information and RAG tools.

    Only the identifiers and ``metadata`` are set at construction time. Every
    other attribute is filled in by the pipeline methods below; each of them
    raises ``ValueError`` when the attribute it depends on has not been set yet:

        file_to_nodes -> nodes_to_summary -> summary_to_oneline
        file_to_nodes -> nodes_to_document_keywords
        file_to_nodes -> nodes_to_storage -> storage_to_retriever
            -> retriever_to_engine -> engine_to_sub_question_engine
            (the last step also needs `summary_oneline`)

    Args:
        name (str): The name of the document.
        file_path (Path | str): The path to the document.
        metadata (dict[str, Any] | None): Optional extra metadata. Defaults to None.
    """

    # Identifiers
    id: UUID  # Random UUID generated in `__init__`.
    name: str
    file_path: Path
    file_name: str  # Final component of `file_path`.

    # Basic Contents
    summary: str
    summary_oneline: str  # A one line summary of the document.
    keywords: set[str]  # Set of keywords in document.
    # entities: Set[str] # list of entities in document ## TODO: Add entities
    metadata: dict[str, Any] | None
    # NOTE: other metadata that might be useful:
    # Document Creation / Last Date (e.g., recency important for legal/medical questions)
    # Document Source and Trustworthiness
    # Document Access Level (though this isn't important for us here.)
    # Document Citations?
    # Document Format? (text/spreadsheet/presentation/image/etc.)

    # RAG Components
    nodes: list[BaseNode]
    docstore: Any  # Whatever `get_docstore` returns; set by `nodes_to_storage`.
    storage_context: StorageContext  # NOTE: current setup has single storage context per document.
    vector_store_index: VectorStoreIndex
    retriever: BaseRetriever  # TODO(Jonathan Wang): Consider multiple retrievers for keywords vs semantic.
    engine: BaseQueryEngine  # TODO(Jonathan Wang): Consider multiple engines.
    subquestion_engine: SubQuestionQueryEngine

    def __init__(
        self,
        name: str,
        file_path: Path | str,
        metadata: dict[str, Any] | None = None
    ) -> None:
        """Create the document shell; contents are added by the pipeline methods.

        Args:
            name (str): The name of the document.
            file_path (Path | str): The path to the document. Anything that is
                not already a `Path` is converted with `Path(...)`.
            metadata (dict[str, Any] | None): Optional extra metadata, stored as given.
        """
        self.id = uuid4()
        self.name = name

        if not isinstance(file_path, Path):
            file_path = Path(file_path)
        self.file_path = file_path
        self.file_name = file_path.name

        self.metadata = metadata

    @classmethod
    def class_name(cls) -> str:
        """Return the fixed class identifier string, "FullDocument"."""
        return "FullDocument"

    @staticmethod
    def _run_coroutine(coroutine: Any) -> Any:
        """Run `coroutine` to completion on this thread's event loop and return its result.

        `asyncio.get_event_loop()` raises `RuntimeError` when the current thread
        has no event loop (e.g. non-main threads), so in that case, or when the
        current loop is closed, a new loop is created and installed first.
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = None
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(coroutine)

    def add_name_to_nodes(self, nodes: list[GenericNode]) -> list[GenericNode]:
        """Add the name of the document to the nodes.

        Args:
            nodes (List[GenericNode]): The nodes to add the name to. Modified in place.

        Returns:
            List[GenericNode]: The same list, with `metadata["name"]` set on every node.
        """
        for node in nodes:
            node.metadata["name"] = self.name
        return nodes

    def file_to_nodes(
        self,
        reader: BaseReader,
        postreaders: list[Callable[[list[GenericNode]], list[GenericNode]] | TransformComponent] | None=None,  # NOTE: these should be used in order. and probably all TransformComponent instead.
        node_parser: NodeParser | None=None,
        postparsers: list[Callable[[list[GenericNode]], list[GenericNode]] | TransformComponent] | None=None,  # Stuff like chunking, adding Embeddings, etc.
    ) -> None:
        """Read the file at `self.file_path` and store the resulting nodes in `self.nodes`.

        Steps, in order: read with `reader`, apply each postreader, parse with
        `node_parser`, apply each postparser, then tag every node with the
        document name via `add_name_to_nodes`.

        Args:
            reader (BaseReader): The reader used to load the file.
            postreaders (optional): Callables applied in order to the reader output.
            node_parser (NodeParser, optional): The parser to use. Defaults to `Settings.node_parser`.
            postparsers (optional): Callables applied in order to the parser output.
                The caller's list is not modified.
        """
        # Use the provided reader to read in the file.
        print("NEWPDF: Reading input file...")
        nodes = reader.load_data(file_path=self.file_path)

        # Use node postreaders to post process the nodes.
        for node_postreader in postreaders or []:
            nodes = node_postreader(nodes)  # type: ignore (TransformComponent allows a list of nodes)

        # Use node parser to parse the nodes.
        if node_parser is None:
            node_parser = Settings.node_parser
        nodes = node_parser(nodes)  # type: ignore (Document is a child of BaseNode)

        # Use node postparsers to post process the nodes, always finishing by adding
        # the common name. Build a NEW list instead of appending to the argument:
        # appending would mutate the caller's list, so a list reused across
        # documents would accumulate the bound methods of earlier documents.
        all_postparsers = [*(postparsers or []), self.add_name_to_nodes]
        for node_postparser in all_postparsers:
            nodes = node_postparser(nodes)  # type: ignore (TransformComponent allows a list of nodes)

        # Save nodes
        self.nodes = nodes  # type: ignore

    def nodes_to_summary(
        self,
        summarizer: BaseSynthesizer,  # NOTE: this is typically going to be a TreeSummarizer / SimpleSummarize for our use case
        query_str: str = DEFAULT_TREE_SUMMARY_TEMPLATE,
    ) -> None:
        """Summarize the nodes and store the result in `self.summary`.

        Args:
            summarizer (BaseSynthesizer): The summarizer to use. Its `aget_response`
                is called with the text of every node that has a `text` attribute.
            query_str (str): The query/prompt passed to the summarizer.

        Raises:
            ValueError: If `file_to_nodes` has not been called yet.
            TypeError: If the summarizer does not return a `str`.
        """
        if not hasattr(self, "nodes"):
            msg = "Nodes must be extracted from document using `file_to_nodes` before calling `nodes_to_summary`."
            raise ValueError(msg)

        text_chunks = [node.text for node in self.nodes if hasattr(node, "text")]
        summary_responses = summarizer.aget_response(query_str=query_str, text_chunks=text_chunks)
        # Run the async summarizer synchronously, tolerating threads without an event loop.
        summary = self._run_coroutine(summary_responses)

        if not isinstance(summary, str):
            # TODO(Jonathan Wang): ... this should always give us a string, right? we're not doing anything fancy with TokenGen/TokenAsyncGen/Pydantic BaseModel...
            msg = f"Summarizer must return a string summary. Actual type: {type(summary)}, with value {summary}."
            raise TypeError(msg)

        self.summary = summary

    def summary_to_oneline(
        self,
        summarizer: BaseSynthesizer,  # NOTE: this is typically going to be a SimpleSummarize / TreeSummarizer for our use case
        query_str: str = DEFAULT_ONELINE_SUMMARY_TEMPLATE,
    ) -> None:
        """Condense `self.summary` and store the result in `self.summary_oneline`.

        Args:
            summarizer (BaseSynthesizer): The summarizer to use. Its synchronous
                `get_response` is called with the summary as the only text chunk.
            query_str (str): The query/prompt passed to the summarizer.

        Raises:
            ValueError: If `nodes_to_summary` has not been called yet.
        """
        if not hasattr(self, "summary"):
            msg = "Summary must be extracted from document using `nodes_to_summary` before calling `summary_to_oneline`."
            raise ValueError(msg)

        oneline = summarizer.get_response(query_str=query_str, text_chunks=[self.summary])  # There's only one chunk.
        self.summary_oneline = oneline  # type: ignore | shouldn't have fancy TokenGenerators / TokenAsyncGenerators / Pydantic BaseModels

    def nodes_to_document_keywords(self, keyword_extractor: KeywordMetadataAdder | None = None) -> None:
        """Save the keywords from the nodes into the document (`self.keywords`).

        Args:
            keyword_extractor (KeywordMetadataAdder, optional): The keyword extractor to use.
                Defaults to a new `KeywordMetadataAdder()`.

        Raises:
            ValueError: If `file_to_nodes` has not been called yet.
        """
        if not hasattr(self, "nodes"):
            msg = "Nodes must be extracted from document using `file_to_nodes` before calling `nodes_to_document_keywords`."
            raise ValueError(msg)

        if keyword_extractor is None:
            keyword_extractor = KeywordMetadataAdder()

        # Add keywords to nodes using KeywordMetadataAdder.
        # The return value is ignored: the node metadata is read back below.
        keyword_extractor.process_nodes(self.nodes)

        # Save keywords
        keywords: set[str] = set()
        for node in self.nodes:
            # NOTE: KeywordMetadataAdder concatenates with ", " b/c required string output
            joined_keywords = node.metadata.get("keyword_metadata", "") or ""
            # Drop blanks: "".split(", ") yields [""], which would otherwise
            # add an empty-string keyword for every node without keywords.
            stripped = (keyword.strip() for keyword in joined_keywords.split(", "))
            keywords.update(keyword for keyword in stripped if keyword)
        # TODO(Jonathan Wang): handle deduplicating keywords which are similar to each other (fuzzy?)
        self.keywords = keywords

    def nodes_to_storage(self, create_new_storage: bool = True) -> None:
        """Save the nodes to storage.

        Sets `self.docstore`, `self.storage_context` and `self.vector_store_index`.

        Args:
            create_new_storage (bool): Must be True; reusing existing storage is not implemented.

        Raises:
            ValueError: If `file_to_nodes` has not been called yet.
            NotImplementedError: If `create_new_storage` is False.
        """
        if not hasattr(self, "nodes"):
            msg = "Nodes must be extracted from document using `file_to_nodes` before calling `nodes_to_storage`."
            raise ValueError(msg)

        if not create_new_storage:
            ### TODO(Jonathan Wang): use an existing storage instead of creating a new one.
            msg = "Currently creates new storage for every document."
            raise NotImplementedError(msg)

        docstore = get_docstore(documents=self.nodes)
        self.docstore = docstore

        vector_store = get_vector_store()
        storage_context = StorageContext.from_defaults(
            docstore=docstore,
            vector_store=vector_store
        )
        self.storage_context = storage_context

        vector_store_index = VectorStoreIndex(
            self.nodes, storage_context=storage_context
        )
        self.vector_store_index = vector_store_index

    # TODO(Jonathan Wang): Create multiple different retrievers based on the question type(?)
    # E.g., if the question is focused on specific keywords or phrases, use a retriever oriented towards sparse scores.
    def storage_to_retriever(
        self,
        semantic_nodes: int = 6,
        sparse_nodes: int = 3,
        fusion_nodes: int = 3,
        semantic_weight: float = 0.6,
        merge_up_thresh: float = 0.5,
        callback_manager: CallbackManager | None=None
    ) -> None:
        """Create retriever from storage and store it in `self.retriever`.

        All numeric arguments are forwarded unchanged to `get_retriever`
        (as `semantic_top_k`, `sparse_top_k`, `fusion_similarity_top_k`,
        `semantic_weight_fraction` and `merge_up_thresh` respectively).

        Args:
            callback_manager (CallbackManager, optional): Falls back to the
                Streamlit session state's `callback_manager` when not given.

        Raises:
            ValueError: If `nodes_to_storage` has not been called yet.
        """
        if not hasattr(self, "vector_store_index"):
            msg = "Vector store must be extracted from document using `nodes_to_storage` before calling `storage_to_retriever`."
            raise ValueError(msg)

        retriever = get_retriever(
            _vector_store_index=self.vector_store_index,
            semantic_top_k=semantic_nodes,
            sparse_top_k=sparse_nodes,
            fusion_similarity_top_k=fusion_nodes,
            semantic_weight_fraction=semantic_weight,
            merge_up_thresh=merge_up_thresh,
            verbose=True,
            _callback_manager=callback_manager or ss.callback_manager
        )
        self.retriever = retriever

    def retriever_to_engine(
        self,
        response_synthesizer: BaseSynthesizer,
        callback_manager: CallbackManager | None=None
    ) -> None:
        """Create query engine from retriever and store it in `self.engine`.

        Args:
            response_synthesizer (BaseSynthesizer): Forwarded to `get_engine`.
            callback_manager (CallbackManager, optional): Falls back to the
                Streamlit session state's `callback_manager` when not given.

        Raises:
            ValueError: If `storage_to_retriever` has not been called yet.
        """
        if not hasattr(self, "retriever"):
            msg = "Retriever must be extracted from document using `storage_to_retriever` before calling `retriever_to_engine`."
            raise ValueError(msg)

        engine = get_engine(
            retriever=self.retriever,
            response_synthesizer=response_synthesizer,
            callback_manager=callback_manager or ss.callback_manager
        )
        self.engine = engine

    # TODO(Jonathan Wang): Create Summarization Index and Engine.
    def engine_to_sub_question_engine(self) -> None:
        """Convert a basic query engine into a sub-question query engine for handling complex, multi-step questions.

        Wraps `self.engine` in a single `QueryEngineTool` described by
        `self.summary_oneline`, and stores the resulting
        `SubQuestionQueryEngine` in `self.subquestion_engine`.

        Raises:
            ValueError: If `summary_oneline` or `engine` has not been created yet.
        """
        if not hasattr(self, "summary_oneline"):
            msg = "One Line Summary must be created for the document before calling `engine_to_sub_question_engine`"
            raise ValueError(msg)
        if not hasattr(self, "engine"):
            msg = "Basic Query Engine must be created before calling `engine_to_sub_question_engine`"
            raise ValueError(msg)

        sqe_tools = [
            QueryEngineTool(
                query_engine=self.engine,  # TODO(Jonathan Wang): handle multiple engines?
                metadata=ToolMetadata(
                    # NOTE(review): there is no separator between the document name and
                    # the suffix; left unchanged in case other code matches on this name.
                    name=(self.name + "simple query answerer"),
                    description=f"""A tool that answers simple questions about the following document: {self.summary_oneline}"""
                )
            )
            # TODO(Jonathan Wang): add more tools
        ]

        subquestion_engine = SubQuestionQueryEngine.from_defaults(
            query_engine_tools=sqe_tools,
            verbose=True,
            use_async=True
        )
        self.subquestion_engine = subquestion_engine