#####################################################
### DOCUMENT PROCESSOR [OBSERVATION/LOGGING]
#####################################################
# Jonathan Wang
# ABOUT:
# This project creates an app to chat with PDFs.
# This module handles observation and logging,
# so we can see the actions undertaken in the RAG pipeline.
#####################################################
## TODOS:
# Why does FullRAGEventHandler keep producing duplicate output?
#####################################################
## IMPORTS:
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, ClassVar, Sequence
import streamlit as st
# Callbacks
from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler
# Pretty Printing
# from llama_index.core.response.notebook_utils import display_source_node
# End user handler
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.instrumentation.event_handlers import BaseEventHandler
from llama_index.core.instrumentation.events.agent import (
AgentChatWithStepEndEvent,
AgentChatWithStepStartEvent,
AgentRunStepEndEvent,
AgentRunStepStartEvent,
AgentToolCallEvent,
)
from llama_index.core.instrumentation.events.chat_engine import (
StreamChatDeltaReceivedEvent,
StreamChatErrorEvent,
)
from llama_index.core.instrumentation.events.embedding import (
EmbeddingEndEvent,
EmbeddingStartEvent,
)
from llama_index.core.instrumentation.events.llm import (
LLMChatEndEvent,
LLMChatInProgressEvent,
LLMChatStartEvent,
LLMCompletionEndEvent,
LLMCompletionStartEvent,
LLMPredictEndEvent,
LLMPredictStartEvent,
LLMStructuredPredictEndEvent,
LLMStructuredPredictStartEvent,
)
from llama_index.core.instrumentation.events.query import (
QueryEndEvent,
QueryStartEvent,
)
from llama_index.core.instrumentation.events.rerank import (
ReRankEndEvent,
ReRankStartEvent,
)
from llama_index.core.instrumentation.events.retrieval import (
RetrievalEndEvent,
RetrievalStartEvent,
)
from llama_index.core.instrumentation.events.span import (
SpanDropEvent,
)
from llama_index.core.instrumentation.events.synthesis import (
# GetResponseEndEvent,
GetResponseStartEvent,
SynthesizeEndEvent,
SynthesizeStartEvent,
)
from llama_index.core.instrumentation.span import SimpleSpan
from llama_index.core.instrumentation.span_handlers.base import BaseSpanHandler
from treelib import Tree
if TYPE_CHECKING:
from llama_index.core.instrumentation.dispatcher import Dispatcher
from llama_index.core.instrumentation.events import BaseEvent
from llama_index.core.schema import BaseNode, NodeWithScore
#####################################################
## Code
# Module-level logger; FullRAGEventHandler emits all of its event details through it.
logger = logging.getLogger(__name__)
@st.cache_resource
def get_callback_manager() -> CallbackManager:
    """Return a CallbackManager wired to a single LlamaDebugHandler.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the same manager
    instead of rebuilding it on every rerun.
    """
    debug_handler = LlamaDebugHandler()
    return CallbackManager([debug_handler])
def display_source_node(source_node: NodeWithScore | BaseNode, max_length: int = 100) -> str:
    """Render a retrieved node as a short Markdown snippet for logging.

    Args:
        source_node: The node to render. Normally a ``NodeWithScore`` (which
            wraps the real node in ``.node`` and carries ``.score``). A bare
            node exposing ``get_content()`` and ``node_id`` directly is also
            accepted; its similarity is then shown as ``None``.
        max_length: Maximum number of characters of (stripped) node text to
            include before truncating and appending ``"..."``.

    Returns:
        Three Markdown lines (node id, similarity, text), each ending in a
        newline.
    """
    # _print_event_nodes is typed to pass bare BaseNode objects too, which are
    # not wrapped; fall back to the object itself and an unknown score.
    node = getattr(source_node, "node", source_node)
    score = getattr(source_node, "score", None)

    source_text = node.get_content().strip()
    if len(source_text) > max_length:
        source_text = source_text[:max_length] + "..."

    return (
        f"**Node ID:** {node.node_id}\n"
        f"**Similarity:** {score}\n"
        f"**Text:** {source_text}\n"
    )
class RAGEventHandler(BaseEventHandler):
    """Lightweight event handler that prints LLM chat events to stdout.

    Every event has its id, timestamp and span id printed between two divider
    lines. LLM chat start / in-progress / end events additionally print their
    payload. Events are not stored anywhere.
    """

    @classmethod
    def class_name(cls) -> str:
        """Return the handler's name."""
        return "RAGEventHandler"

    def handle(self, event: BaseEvent, **kwargs: Any) -> None:
        """Print the common fields of ``event`` and, for LLM chat events, their payload."""
        divider = "-----------------------"
        print(divider)

        # Fields present on every event.
        for shared_field in ("id_", "timestamp", "span_id"):
            print(getattr(event, shared_field))

        # Event-specific payloads.
        if isinstance(event, LLMChatStartEvent):
            print(event.messages)
            print(event.additional_kwargs)
            print(event.model_dict)
        elif isinstance(event, LLMChatInProgressEvent):
            print(event.response.delta)  # streaming increment
        elif isinstance(event, LLMChatEndEvent):
            print(event.response)  # final response

        print(divider)
class FullRAGEventHandler(BaseEventHandler):
    """RAG event handler that logs every instrumentation event in detail.

    Built off the example custom event handler. In general, logged events are
    treated as single events at a point in time that link to a span. A span is
    a collection of events related to a single task and is identified by a
    unique ``span_id``.

    While events are independent, there is some hierarchy. For example, a
    ``query_engine.query()`` call with a reranker attached emits:
    - QueryStartEvent
    - RetrievalStartEvent
    - EmbeddingStartEvent
    - EmbeddingEndEvent
    - RetrievalEndEvent
    - ReRankStartEvent
    - ReRankEndEvent
    - SynthesizeStartEvent
    - GetResponseStartEvent
    - LLMPredictStartEvent
    - LLMChatStartEvent
    - LLMChatEndEvent
    - LLMPredictEndEvent
    - GetResponseEndEvent
    - SynthesizeEndEvent
    - QueryEndEvent
    """

    # Every handled event is appended here. Because this is a ClassVar, the
    # list is shared by ALL instances: if several FullRAGEventHandler instances
    # are registered, each event is appended once per instance. It is never
    # pruned.
    # NOTE(review): handlers re-registered on a shared dispatcher (e.g. once
    # per Streamlit rerun) would also log each event multiple times; this is
    # a likely source of "duplicate output" — confirm where handlers are added.
    events: ClassVar[list[BaseEvent]] = []

    @classmethod
    def class_name(cls) -> str:
        """Class name (distinct from the pruned RAGEventHandler)."""
        return "FullRAGEventHandler"

    def _print_event_nodes(self, event_nodes: Sequence[NodeWithScore | BaseNode]) -> str:
        """Format nodes as a bracketed list with a divider line between entries.

        Each rendered node already ends in a newline, so the divider gets its
        own line and the next node starts on a fresh line.
        """
        divider = "* * * * * * * * * * * *\n"
        rendered = divider.join(display_source_node(node, 1000) for node in event_nodes)
        return f"[{rendered}]"

    def handle(self, event: BaseEvent, **kwargs: Any) -> None:
        """Log the common and type-specific fields of ``event``, then record it in ``events``."""
        logger.info("-----------------------")
        # all events have these attributes
        logger.info(event.id_)
        logger.info(event.timestamp)
        logger.info(event.span_id)

        # event specific attributes
        logger.info("Event type: %s", event.class_name())
        if isinstance(event, AgentRunStepStartEvent):
            logger.info(event.step)
            logger.info(event.input)
        if isinstance(event, AgentRunStepEndEvent):
            logger.info(event.step_output)
        if isinstance(event, AgentChatWithStepStartEvent):
            logger.info(event.user_msg)
        if isinstance(event, AgentChatWithStepEndEvent):
            logger.info(event.response)
        if isinstance(event, AgentToolCallEvent):
            logger.info(event.arguments)
            logger.info(event.tool.name)
            logger.info(event.tool.description)
        if isinstance(event, StreamChatDeltaReceivedEvent):
            logger.info(event.delta)
        if isinstance(event, StreamChatErrorEvent):
            logger.info(event.exception)
        if isinstance(event, EmbeddingStartEvent):
            logger.info(event.model_dict)
        if isinstance(event, EmbeddingEndEvent):
            logger.info(event.chunks)
            # Log only a 5-value preview of the first embedding, and only if
            # there is one (an empty list would raise IndexError here).
            if event.embeddings:
                logger.info(event.embeddings[0][:5])
        if isinstance(event, LLMPredictStartEvent):
            logger.info(event.template)
            logger.info(event.template_args)
        if isinstance(event, LLMPredictEndEvent):
            logger.info(event.output)
        if isinstance(event, LLMStructuredPredictStartEvent):
            logger.info(event.template)
            logger.info(event.template_args)
            logger.info(event.output_cls)
        if isinstance(event, LLMStructuredPredictEndEvent):
            logger.info(event.output)
        if isinstance(event, LLMCompletionStartEvent):
            logger.info(event.model_dict)
            logger.info(event.prompt)
            logger.info(event.additional_kwargs)
        if isinstance(event, LLMCompletionEndEvent):
            logger.info(event.response)
            logger.info(event.prompt)
        if isinstance(event, LLMChatInProgressEvent):
            logger.info(event.messages)
            logger.info(event.response)
        if isinstance(event, LLMChatStartEvent):
            logger.info(event.messages)
            logger.info(event.additional_kwargs)
            logger.info(event.model_dict)
        if isinstance(event, LLMChatEndEvent):
            logger.info(event.messages)
            logger.info(event.response)
        if isinstance(event, RetrievalStartEvent):
            logger.info(event.str_or_query_bundle)
        if isinstance(event, RetrievalEndEvent):
            logger.info(event.str_or_query_bundle)
            logger.info(self._print_event_nodes(event.nodes))
        if isinstance(event, ReRankStartEvent):
            logger.info(event.query)
            for node in event.nodes:
                logger.info(display_source_node(node))
            logger.info(event.top_n)
            logger.info(event.model_name)
        if isinstance(event, ReRankEndEvent):
            logger.info(self._print_event_nodes(event.nodes))
        if isinstance(event, QueryStartEvent):
            logger.info(event.query)
        if isinstance(event, QueryEndEvent):
            logger.info(event.response)
            logger.info(event.query)
        if isinstance(event, SpanDropEvent):
            logger.info(event.err_str)
        if isinstance(event, SynthesizeStartEvent):
            logger.info(event.query)
        if isinstance(event, SynthesizeEndEvent):
            logger.info(event.response)
            logger.info(event.query)
        if isinstance(event, GetResponseStartEvent):
            logger.info(event.query_str)

        self.events.append(event)
        logger.info("-----------------------")

    def _get_events_by_span(self) -> dict[str, list[BaseEvent]]:
        """Group recorded events by ``span_id``, preserving recording order.

        Events without a span id are skipped. An event recorded more than once
        (same ``id_``, e.g. because several instances append to the shared
        ``events`` list) is kept only once. treelib rejects duplicate node
        identifiers, so duplicates would otherwise break tree building.
        """
        events_by_span: dict[str, list[BaseEvent]] = {}
        seen_event_ids: set[str] = set()
        for event in self.events:
            if event.span_id is None or event.id_ in seen_event_ids:
                continue
            seen_event_ids.add(event.id_)
            events_by_span.setdefault(event.span_id, []).append(event)
        return events_by_span

    def _get_event_span_trees(self) -> list[Tree]:
        """Build one treelib Tree per span: the span is the root, its events the children.

        Node ``data`` holds the timestamp so trees can be displayed in time order.
        """
        trees: list[Tree] = []
        for span_id, span_events in self._get_events_by_span().items():
            tree = Tree()
            tree.create_node(
                tag=f"{span_id} (SPAN)",
                identifier=span_id,
                parent=None,
                data=span_events[0].timestamp,
            )
            for event in span_events:
                tree.create_node(
                    tag=f"{event.class_name()}: {event.id_}",
                    identifier=event.id_,
                    parent=span_id,
                    data=event.timestamp,
                )
            trees.append(tree)
        return trees

    def print_event_span_trees(self) -> None:
        """Log every span tree, with children sorted by timestamp."""
        for tree in self._get_event_span_trees():
            logger.info(
                tree.show(
                    stdout=False, sorting=True, key=lambda node: node.data
                )
            )
            logger.info("")
class RAGSpanHandler(BaseSpanHandler[SimpleSpan]):
    """Span handler that tracks currently open spans as ``SimpleSpan`` objects.

    A span is recorded when it is created and removed again when it exits or
    is dropped, so ``span_dict`` no longer grows without bound.
    """

    # span id -> stack of SimpleSpan objects currently open under that id.
    span_dict: dict = {}

    @classmethod
    def class_name(cls) -> str:
        """Class name."""
        return "RAGSpanHandler"

    def new_span(
        self,
        id_: str,
        bound_args: Any,
        instance: Any | None = None,
        parent_span_id: str | None = None,
        **kwargs: Any,
    ) -> SimpleSpan | None:
        """Create a SimpleSpan for ``id_``, record it in ``span_dict`` and return it."""
        span = SimpleSpan(id_=id_, parent_id=parent_span_id)
        self.span_dict.setdefault(id_, []).append(span)
        return span

    def _pop_span(self, id_: str) -> SimpleSpan | None:
        """Remove and return the most recent span recorded under ``id_``.

        Returns None if nothing is recorded (e.g. the span was opened before
        this handler was attached). Empty entries are deleted so the dict only
        holds open spans.
        """
        spans = self.span_dict.get(id_)
        if not spans:
            return None
        span = spans.pop()
        if not spans:
            del self.span_dict[id_]
        return span

    def prepare_to_exit_span(
        self,
        id_: str,
        bound_args: Any,
        instance: Any | None = None,
        result: Any | None = None,
        **kwargs: Any,
    ) -> Any:
        """Stop tracking the exiting span and return it (None if it was not tracked)."""
        return self._pop_span(id_)

    def prepare_to_drop_span(
        self,
        id_: str,
        bound_args: Any,
        instance: Any | None = None,
        err: BaseException | None = None,
        **kwargs: Any,
    ) -> Any:
        """Stop tracking the dropped span and return it (None if it was not tracked)."""
        return self._pop_span(id_)
def get_obs() -> Dispatcher:
    """Attach the RAG event and span handlers to the dispatcher and return it.

    ``get_dispatcher()`` is called with no name, so it presumably returns the
    shared root dispatcher. Handlers added to it persist across calls. Adding
    a new handler on every call (e.g. every Streamlit rerun) would make each
    event be handled, and printed, once per registration. Each handler type
    is therefore only added if one is not already attached, so repeated calls
    are idempotent.
    """
    dispatcher = get_dispatcher()
    if not any(isinstance(h, RAGEventHandler) for h in dispatcher.event_handlers):
        dispatcher.add_event_handler(RAGEventHandler())
    if not any(isinstance(h, RAGSpanHandler) for h in dispatcher.span_handlers):
        dispatcher.add_span_handler(RAGSpanHandler())
    return dispatcher