Spaces:
Paused
Paused
##################################################### | |
### DOCUMENT PROCESSOR [OBSERVATION/LOGGING] | |
##################################################### | |
# Jonathan Wang | |
# ABOUT: | |
# This project creates an app to chat with PDFs. | |
# This module provides the observation and logging
# used to see the actions undertaken in the RAG pipeline.
##################################################### | |
## TODOS: | |
# Why does FullRAGEventHandler keep producing duplicate output? | |
##################################################### | |
## IMPORTS: | |
from __future__ import annotations | |
import logging | |
from typing import TYPE_CHECKING, Any, ClassVar, Sequence | |
import streamlit as st | |
# Callbacks | |
from llama_index.core.callbacks import CallbackManager, LlamaDebugHandler | |
# Pretty Printing | |
# from llama_index.core.response.notebook_utils import display_source_node | |
# End user handler | |
from llama_index.core.instrumentation import get_dispatcher | |
from llama_index.core.instrumentation.event_handlers import BaseEventHandler | |
from llama_index.core.instrumentation.events.agent import ( | |
AgentChatWithStepEndEvent, | |
AgentChatWithStepStartEvent, | |
AgentRunStepEndEvent, | |
AgentRunStepStartEvent, | |
AgentToolCallEvent, | |
) | |
from llama_index.core.instrumentation.events.chat_engine import ( | |
StreamChatDeltaReceivedEvent, | |
StreamChatErrorEvent, | |
) | |
from llama_index.core.instrumentation.events.embedding import ( | |
EmbeddingEndEvent, | |
EmbeddingStartEvent, | |
) | |
from llama_index.core.instrumentation.events.llm import ( | |
LLMChatEndEvent, | |
LLMChatInProgressEvent, | |
LLMChatStartEvent, | |
LLMCompletionEndEvent, | |
LLMCompletionStartEvent, | |
LLMPredictEndEvent, | |
LLMPredictStartEvent, | |
LLMStructuredPredictEndEvent, | |
LLMStructuredPredictStartEvent, | |
) | |
from llama_index.core.instrumentation.events.query import ( | |
QueryEndEvent, | |
QueryStartEvent, | |
) | |
from llama_index.core.instrumentation.events.rerank import ( | |
ReRankEndEvent, | |
ReRankStartEvent, | |
) | |
from llama_index.core.instrumentation.events.retrieval import ( | |
RetrievalEndEvent, | |
RetrievalStartEvent, | |
) | |
from llama_index.core.instrumentation.events.span import ( | |
SpanDropEvent, | |
) | |
from llama_index.core.instrumentation.events.synthesis import ( | |
# GetResponseEndEvent, | |
GetResponseStartEvent, | |
SynthesizeEndEvent, | |
SynthesizeStartEvent, | |
) | |
from llama_index.core.instrumentation.span import SimpleSpan | |
from llama_index.core.instrumentation.span_handlers.base import BaseSpanHandler | |
from treelib import Tree | |
if TYPE_CHECKING: | |
from llama_index.core.instrumentation.dispatcher import Dispatcher | |
from llama_index.core.instrumentation.events import BaseEvent | |
from llama_index.core.schema import BaseNode, NodeWithScore | |
##################################################### | |
## Code | |
logger = logging.getLogger(__name__) | |
def get_callback_manager() -> CallbackManager:
    """Build the callback manager used by the code.

    Returns:
        A ``CallbackManager`` holding a single ``LlamaDebugHandler``.
    """
    debug_handler = LlamaDebugHandler()
    handlers = [debug_handler]
    return CallbackManager(handlers)
def display_source_node(source_node: NodeWithScore, max_length: int = 100) -> str:
    """Render a scored node as a short, ``<br>``-separated summary string.

    Args:
        source_node: Scored node; ``source_node.node`` supplies the id and
            the content, ``source_node.score`` is shown as the similarity.
        max_length: Number of content characters kept before the text is
            cut and ``"..."`` is appended.

    Returns:
        The node id, similarity and (possibly truncated) stripped text,
        each one followed by ``<br>``.
    """
    node = source_node.node
    text = node.get_content().strip()
    if len(text) > max_length:
        # Keep only the leading characters and mark that the text was cut.
        text = f"{text[:max_length]}..."
    fields = (
        ("Node ID", node.node_id),
        ("Similarity", source_node.score),
        ("Text", text),
    )
    return "".join(f"**{label}:** {value}<br>" for label, value in fields)
class RAGEventHandler(BaseEventHandler):
    """Pruned RAG event handler that prints events to stdout.

    For every event the id, timestamp and span id are printed. LLM chat
    start, in-progress and end events additionally print their payload.
    Nothing is stored on the handler.
    """

    # events: List[BaseEvent] = []  # TODO: handle removing historical events if they're too old.

    @classmethod
    def class_name(cls) -> str:
        """Class name."""
        return "RAGEventHandler"

    def handle(self, event: BaseEvent, **kwargs: Any) -> None:
        """Logic for handling event.

        Args:
            event: The instrumentation event to print.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        print("-----------------------")
        # all events have these attributes
        print(event.id_)
        print(event.timestamp)
        print(event.span_id)

        # event specific attributes
        if isinstance(event, LLMChatStartEvent):
            # initial
            print(event.messages)
            print(event.additional_kwargs)
            print(event.model_dict)
        elif isinstance(event, LLMChatInProgressEvent):
            # streaming
            print(event.response.delta)
        elif isinstance(event, LLMChatEndEvent):
            # final response
            print(event.response)

        # self.events.append(event)  # see the TODO on the class: events are not retained yet.
        print("-----------------------")
class FullRAGEventHandler(BaseEventHandler):
    """RAG event handler. Built off the example custom event handler.

    In general, logged events are treated as single events in a point in time,
    that link to a span. The span is a collection of events that are related to
    a single task. The span is identified by a unique span_id.

    While events are independent, there is some hierarchy.
    For example, in query_engine.query() call with a reranker attached:

    - QueryStartEvent
        - RetrievalStartEvent
            - EmbeddingStartEvent
            - EmbeddingEndEvent
        - RetrievalEndEvent
        - RerankStartEvent
        - RerankEndEvent
        - SynthesizeStartEvent
            - GetResponseStartEvent
                - LLMPredictStartEvent
                    - LLMChatStartEvent
                    - LLMChatEndEvent
                - LLMPredictEndEvent
            - GetResponseEndEvent
        - SynthesizeEndEvent
    - QueryEndEvent
    """

    # Class-level list, shared by every instance of this handler. ``handle``
    # appends to it and nothing in this class ever removes entries, so if more
    # than one instance handles the same event it is stored more than once.
    events: ClassVar[list[BaseEvent]] = []

    @classmethod
    def class_name(cls) -> str:
        """Class name."""
        return "FullRAGEventHandler"

    def _print_event_nodes(self, event_nodes: Sequence[NodeWithScore | BaseNode]) -> str:
        """Format a list of nodes nicely.

        Args:
            event_nodes: Nodes to format; each is passed to
                ``display_source_node`` with a 1000 character text limit.

        Returns:
            One string: ``[``, then for each node its summary, a newline and
            a starred separator, then ``]``.
        """
        separator = "* * * * * * * * * * * *"
        body = "".join(
            f"{display_source_node(node, 1000)}\n{separator}" for node in event_nodes
        )
        return f"[{body}]"

    def handle(self, event: BaseEvent, **kwargs: Any) -> None:
        """Logic for handling event.

        Logs the attributes common to all events, then the attributes
        specific to the event's type, and finally records the event in
        ``events`` so the span trees can be built from it.

        Args:
            event: The instrumentation event to log and record.
            **kwargs: Accepted for interface compatibility; not used here.
        """
        logger.info("-----------------------")
        # all events have these attributes
        logger.info(event.id_)
        logger.info(event.timestamp)
        logger.info(event.span_id)

        # event specific attributes
        logger.info("Event type: %s", event.class_name())
        if isinstance(event, AgentRunStepStartEvent):
            # logger.info(event.task_id)
            logger.info(event.step)
            logger.info(event.input)
        if isinstance(event, AgentRunStepEndEvent):
            logger.info(event.step_output)
        if isinstance(event, AgentChatWithStepStartEvent):
            logger.info(event.user_msg)
        if isinstance(event, AgentChatWithStepEndEvent):
            logger.info(event.response)
        if isinstance(event, AgentToolCallEvent):
            logger.info(event.arguments)
            logger.info(event.tool.name)
            logger.info(event.tool.description)
        if isinstance(event, StreamChatDeltaReceivedEvent):
            logger.info(event.delta)
        if isinstance(event, StreamChatErrorEvent):
            logger.info(event.exception)
        if isinstance(event, EmbeddingStartEvent):
            logger.info(event.model_dict)
        if isinstance(event, EmbeddingEndEvent):
            logger.info(event.chunks)
            if event.embeddings:
                # Only the first 5 values of the first embedding, to avoid
                # printing all embeddings; skipped when there are none.
                logger.info(event.embeddings[0][:5])
        if isinstance(event, LLMPredictStartEvent):
            logger.info(event.template)
            logger.info(event.template_args)
        if isinstance(event, LLMPredictEndEvent):
            logger.info(event.output)
        if isinstance(event, LLMStructuredPredictStartEvent):
            logger.info(event.template)
            logger.info(event.template_args)
            logger.info(event.output_cls)
        if isinstance(event, LLMStructuredPredictEndEvent):
            logger.info(event.output)
        if isinstance(event, LLMCompletionStartEvent):
            logger.info(event.model_dict)
            logger.info(event.prompt)
            logger.info(event.additional_kwargs)
        if isinstance(event, LLMCompletionEndEvent):
            logger.info(event.response)
            logger.info(event.prompt)
        if isinstance(event, LLMChatInProgressEvent):
            logger.info(event.messages)
            logger.info(event.response)
        if isinstance(event, LLMChatStartEvent):
            logger.info(event.messages)
            logger.info(event.additional_kwargs)
            logger.info(event.model_dict)
        if isinstance(event, LLMChatEndEvent):
            logger.info(event.messages)
            logger.info(event.response)
        if isinstance(event, RetrievalStartEvent):
            logger.info(event.str_or_query_bundle)
        if isinstance(event, RetrievalEndEvent):
            logger.info(event.str_or_query_bundle)
            # logger.info(event.nodes)
            logger.info(self._print_event_nodes(event.nodes))
        if isinstance(event, ReRankStartEvent):
            logger.info(event.query)
            # logger.info(event.nodes)
            for node in event.nodes:
                logger.info(display_source_node(node))
            logger.info(event.top_n)
            logger.info(event.model_name)
        if isinstance(event, ReRankEndEvent):
            # logger.info(event.nodes)
            logger.info(self._print_event_nodes(event.nodes))
        if isinstance(event, QueryStartEvent):
            logger.info(event.query)
        if isinstance(event, QueryEndEvent):
            logger.info(event.response)
            logger.info(event.query)
        if isinstance(event, SpanDropEvent):
            logger.info(event.err_str)
        if isinstance(event, SynthesizeStartEvent):
            logger.info(event.query)
        if isinstance(event, SynthesizeEndEvent):
            logger.info(event.response)
            logger.info(event.query)
        if isinstance(event, GetResponseStartEvent):
            logger.info(event.query_str)

        self.events.append(event)
        logger.info("-----------------------")

    def _get_events_by_span(self) -> dict[str, list[BaseEvent]]:
        """Group the recorded events by their span id.

        Events without a span id are left out. An event id that was already
        grouped is skipped, because the shared ``events`` list can hold the
        same event more than once.

        Returns:
            Mapping of span id to that span's events, in recorded order.
        """
        events_by_span: dict[str, list[BaseEvent]] = {}
        seen_event_ids: set[str] = set()
        for event in self.events:
            if event.span_id is None or event.id_ in seen_event_ids:
                continue
            seen_event_ids.add(event.id_)
            events_by_span.setdefault(event.span_id, []).append(event)
        return events_by_span

    def _get_event_span_trees(self) -> list[Tree]:
        """Build one tree per span: the span as root, its events as children.

        Returns:
            A list of trees. Every node's ``data`` is a timestamp (the root
            uses the timestamp of the span's first recorded event), which
            ``print_event_span_trees`` uses as the sort key.
        """
        trees: list[Tree] = []
        for span_id, span_events in self._get_events_by_span().items():
            tree = Tree()
            # create root node i.e. span node
            tree.create_node(
                tag=f"{span_id} (SPAN)",
                identifier=span_id,
                parent=None,
                data=span_events[0].timestamp,
            )
            for event in span_events:
                tree.create_node(
                    tag=f"{event.class_name()}: {event.id_}",
                    identifier=event.id_,
                    parent=span_id,
                    data=event.timestamp,
                )
            trees.append(tree)
        return trees

    def print_event_span_trees(self) -> None:
        """View trace trees.

        Logs each span tree, with nodes sorted by their timestamp data,
        followed by an empty log line.
        """
        for tree in self._get_event_span_trees():
            logger.info(
                tree.show(
                    stdout=False, sorting=True, key=lambda node: node.data
                )
            )
            logger.info("")
class RAGSpanHandler(BaseSpanHandler[SimpleSpan]):
    """Span handler that records a ``SimpleSpan`` for every span created.

    Spans are only collected in ``span_dict``; the exit and drop hooks do
    nothing, so recorded spans are never removed by this class.
    """

    # Maps a span id to the list of SimpleSpan objects created for that id.
    span_dict: dict = {}

    @classmethod
    def class_name(cls) -> str:
        """Class name."""
        return "RAGSpanHandler"

    def new_span(
        self,
        id_: str,
        bound_args: Any,
        instance: Any | None = None,
        parent_span_id: str | None = None,
        **kwargs: Any,
    ) -> SimpleSpan | None:
        """Create a span.

        Args:
            id_: Identifier of the new span.
            bound_args: Not used here.
            instance: Not used here.
            parent_span_id: Identifier of the parent span, stored on the
                created span as ``parent_id``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Always ``None``; the created span is only kept in ``span_dict``.
        """
        span = SimpleSpan(id_=id_, parent_id=parent_span_id)
        self.span_dict.setdefault(id_, []).append(span)
        # NOTE(review): the span is deliberately not returned, matching the
        # previous behaviour. Confirm whether the base handler is expected
        # to receive it.
        return None

    def prepare_to_exit_span(
        self,
        id_: str,
        bound_args: Any,
        instance: Any | None = None,
        result: Any | None = None,
        **kwargs: Any,
    ) -> Any:
        """Logic for preparing to exit a span.

        Currently a no-op that returns ``None``.
        """
        # if id in self.span_dict:
        #     return self.span_dict[id].pop()
        return None

    def prepare_to_drop_span(
        self,
        id_: str,
        bound_args: Any,
        instance: Any | None = None,
        err: BaseException | None = None,
        **kwargs: Any,
    ) -> Any:
        """Logic for preparing to drop a span.

        Currently a no-op that returns ``None``.
        """
        # if id in self.span_dict:
        #     return self.span_dict[id].pop()
        return None
def get_obs() -> Dispatcher:
    """Get observability for the RAG pipeline.

    Fetches the dispatcher from ``get_dispatcher()`` and makes sure a
    ``RAGEventHandler`` and a ``RAGSpanHandler`` are attached to it. A handler
    type that is already attached is not added again, so calling this
    function repeatedly does not make every event get handled (and printed)
    once per call.

    Returns:
        The dispatcher with both handlers attached.
    """
    dispatcher = get_dispatcher()

    # NOTE(review): get_dispatcher() presumably returns the same dispatcher
    # object on every call. If it returns a fresh one, these checks simply
    # find nothing and the handlers are added as before.
    if not any(isinstance(h, RAGEventHandler) for h in dispatcher.event_handlers):
        dispatcher.add_event_handler(RAGEventHandler())
    if not any(isinstance(h, RAGSpanHandler) for h in dispatcher.span_handlers):
        dispatcher.add_span_handler(RAGSpanHandler())

    return dispatcher