# Spaces:
# Sleeping
# Sleeping
import json
import re
import traceback
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Any, Optional, List, Union, Tuple

from logger_config import setup_logger
# Module-level logger obtained from the project's logger_config helper.
logger = setup_logger()
@dataclass(eq=False)
class MessageState:
    """Mutable accumulator for the message currently being streamed.

    ``eq=False`` keeps the identity-based equality and hashing that a plain
    class has, so instances can still be used in sets or as dict keys.
    Every field has a default, so ``MessageState()`` yields a fresh, empty
    state; mutable fields use ``default_factory`` so instances never share
    containers.
    """

    # Text buffer; not read or written by the parser code visible in this file.
    buffer: str = ""
    # Set to True by SSEParser._handle_message_end.
    is_complete: bool = False
    # Dicts of the form {"type": ..., "content": ...} appended by
    # SSEParser._process_observation.
    tool_outputs: List[Dict[str, Any]] = field(default_factory=list)
    # Filled from the "retriever_resources" entry of the message-end event.
    citations: List[Any] = field(default_factory=list)
    # Filled from the "metadata" entry of the message-end event.
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Presumably identifiers of events already handled - TODO confirm; unused
    # by the code visible in this file.
    processed_events: set = field(default_factory=set)
    # Identifier of the message in progress, if any; unused by the code
    # visible in this file.
    current_message_id: Optional[Any] = None
class SSEParser:
    """Turn raw server-sent-event text into dictionaries for the frontend.

    After the ``data:`` marker three payload shapes are recognised: an
    ``<agent_response>`` XML fragment, a JSON document (with special handling
    for mermaid tool output), and anything else, which is passed through as a
    plain message.  Per-message state (tool outputs, citations, metadata) is
    accumulated on ``self.current_message``.
    """

    # Keys under which a JSON payload may carry a mermaid diagram, in the
    # order they are checked.
    _MERMAID_KEYS = ("mermaid_output", "mermaid_diagram")

    def __init__(self):
        self.logger = setup_logger("sse_parser")
        self.current_message = MessageState()

    @staticmethod
    def _preview(value: Any, limit: int = 200) -> str:
        """Return at most *limit* characters of *value* for log output.

        Non-string values are converted with ``str()`` first so that logging
        can never fail on a value that does not support slicing.
        """
        text = value if isinstance(value, str) else str(value)
        return text[:limit]

    @classmethod
    def _find_mermaid_key(cls, data: Any) -> Optional[str]:
        """Return the first mermaid key present in *data*, or None.

        *data* may be any parsed JSON value; only dicts can contain a key.
        """
        if isinstance(data, dict):
            for key in cls._MERMAID_KEYS:
                if key in data:
                    return key
        return None

    def _extract_json_content(self, data: str) -> Optional[str]:
        """Return the stripped text after the first ``data:`` marker.

        Returns None when *data* contains no ``data:`` marker at all.
        """
        _, marker, payload = data.partition("data:")
        return payload.strip() if marker else None

    def _is_valid_json(self, content: str) -> bool:
        """Return True when *content* parses as JSON, False otherwise."""
        try:
            json.loads(content)
        except (TypeError, ValueError):
            # ValueError covers json.JSONDecodeError; TypeError covers
            # non-string input such as None.
            return False
        return True

    def _clean_mermaid_content(self, content: Union[str, Dict[str, Any]]) -> Optional[str]:
        """Extract the bare mermaid diagram text from a tool payload.

        Accepts either a string or a dict.  A dict with a ``"tool_output"``
        entry is unwrapped first.  For strings, a ``tool response:`` prefix
        and a single trailing ``.`` are removed, then the text is parsed as
        JSON (up to two levels) looking for ``mermaid_output`` or
        ``mermaid_diagram``.  Finally markdown fences are stripped.

        Returns the cleaned diagram string, or None when no string could be
        produced or an unexpected error occurred (the error is logged).
        """
        original = content
        try:
            self.logger.debug(f"Starting mermaid content cleaning. Input type: {type(content)}")
            # Preview through str() so dict input cannot fail on slicing.
            self.logger.debug(f"Initial content: {self._preview(content)}...")

            # Handle tool output format
            if isinstance(content, dict) and "tool_output" in content:
                self.logger.debug("Found tool_output in dict, extracting...")
                content = content["tool_output"]
                self.logger.debug(f"Extracted tool_output: {self._preview(content)}...")

            # Remove tool response prefix/suffix if present
            if isinstance(content, str):
                self.logger.debug("Processing string content...")
                if "tool response:" in content:
                    self.logger.debug("Found 'tool response:' prefix, removing...")
                    content = content.split("tool response:", 1)[1].strip()
                if content.endswith('.'):
                    content = content[:-1]
                self.logger.debug(f"After prefix/suffix removal: {self._preview(content)}...")

            # Parse JSON if present
            try:
                if isinstance(content, str):
                    self.logger.debug("Attempting to parse content as JSON...")
                    data = json.loads(content)
                    if isinstance(data, dict):
                        self.logger.debug(f"JSON parsed successfully. Keys: {data.keys()}")
                    else:
                        self.logger.debug(f"JSON parsed successfully. Type: {type(data)}")
                else:
                    data = content
                    self.logger.debug(f"Using content as data directly. Type: {type(data)}")
            except json.JSONDecodeError as e:
                # Not JSON: treat the text itself as the diagram.
                self.logger.debug(f"Initial JSON parse failed: {e}")
            else:
                # Handle different mermaid output formats
                key = self._find_mermaid_key(data)
                if key is not None:
                    self.logger.debug(f"Found {key} format")
                    content = data[key]

                # If content is still JSON string, parse again
                if isinstance(content, str) and content.startswith('{'):
                    self.logger.debug("Content still appears to be JSON, attempting second parse...")
                    try:
                        nested = json.loads(content)
                    except ValueError as e:
                        self.logger.debug(f"Second JSON parse failed: {e}")
                    else:
                        key = self._find_mermaid_key(nested)
                        if key is not None:
                            content = nested[key]
                            self.logger.debug(f"Extracted {key} from second parse")

            if not isinstance(content, str):
                self.logger.warning("Content not in string format after processing")
                return None

            # Clean up markdown formatting
            self.logger.debug("Cleaning markdown formatting...")
            content = content.replace("```mermaid\n", "").replace("\n```", "")
            content = content.strip()

            # Remove any remaining JSON artifacts: fall back to the first
            # value of a leftover JSON object.
            if content.startswith('{'):
                self.logger.debug("Attempting to clean remaining JSON artifacts...")
                try:
                    leftover = json.loads(content)
                except ValueError as e:
                    self.logger.debug(f"Final JSON cleanup failed: {e}")
                else:
                    if isinstance(leftover, dict) and leftover:
                        content = next(iter(leftover.values()))
                        self.logger.debug("Extracted value from remaining JSON")

            if not isinstance(content, str):
                # The leftover object's first value was not text.
                self.logger.warning("Content not in string format after processing")
                return None

            self.logger.debug(f"Final cleaned content: {self._preview(content)}...")
            return content
        except Exception as e:
            # Best-effort cleaner: never propagate, report and give up.
            self.logger.error(f"Error cleaning mermaid content: {e}")
            self.logger.error(f"Original content: {original}")
            self.logger.error(f"Stack trace: {traceback.format_exc()}")
            return None

    def parse_sse_event(self, data: str) -> Optional[Dict]:
        """Parse SSE event data and format it for frontend consumption.

        Returns:
            * None when there is no ``data:`` payload or an unexpected error
              occurs (the error is logged).
            * The result of ``_parse_xml_response`` when the payload contains
              ``<agent_response>``.
            * ``{"type": "tool_output", "tool": "mermaid", "content": ...}``
              when the payload is a JSON object with a mermaid key.
            * The parsed JSON object itself for any other JSON object.
            * ``{"type": "message", "content": <raw payload>}`` when the
              payload is not a JSON object (plain text, or a JSON scalar or
              array such as a streamed ``42``).
        """
        try:
            self.logger.debug(f"Parsing SSE event. Raw data length: {len(data)}")

            # Clean up the data format - remove extra data: prefixes
            data = data.replace('data: data:', 'data:').replace('\r\n', '\n')

            # Extract JSON content from SSE data
            json_content = self._extract_json_content(data)
            if not json_content:
                self.logger.debug("No JSON content found in SSE data")
                return None

            # Handle text-wrapped JSON
            if json_content.startswith('{"text":'):
                try:
                    wrapper = json.loads(json_content)
                except ValueError:
                    # Leave the payload wrapped; it is handled as raw text below.
                    self.logger.debug("Text wrapper is not valid JSON, keeping payload as-is")
                else:
                    inner = wrapper.get("text", "")
                    # Only unwrap textual payloads; anything else stays wrapped
                    # so the string handling below cannot fail.
                    if isinstance(inner, str):
                        json_content = inner

            self.logger.debug(f"Cleaned JSON content: {json_content[:200]}...")

            # Parse XML content if present
            if '<agent_response>' in json_content:
                return self._parse_xml_response(json_content)

            # Parse JSON content
            try:
                parsed_data = json.loads(json_content)
            except json.JSONDecodeError:
                self.logger.debug("Failed to parse as JSON, treating as raw content")
                return {
                    "type": "message",
                    "content": json_content
                }

            if not isinstance(parsed_data, dict):
                # Text that happens to be a JSON scalar/array is still text.
                self.logger.debug("Parsed JSON is not an object, treating as raw content")
                return {
                    "type": "message",
                    "content": json_content
                }

            self.logger.debug(f"Parsed data keys: {parsed_data.keys()}")

            # Handle tool outputs
            if self._find_mermaid_key(parsed_data) is not None:
                return {
                    "type": "tool_output",
                    "tool": "mermaid",
                    "content": self._clean_mermaid_content(json_content)
                }
            return parsed_data
        except Exception as e:
            self.logger.error(f"Parse error: {e}")
            self.logger.error(f"Raw data: {data}")
            self.logger.error(f"Stack trace: {traceback.format_exc()}")
            return None

    def _parse_xml_response(self, content: str) -> Optional[Dict]:
        """Parse the XML response format.

        A ``<message>`` element takes precedence and yields a message dict.
        Otherwise a ``<tool_output ...>`` element yields a mermaid tool
        output, but only when the word "mermaid" appears anywhere in
        *content* (case-insensitive).  Returns None in every other case,
        including on error.
        """
        try:
            # Extract message content
            message_match = re.search(r'<message>(.*?)</message>', content, re.DOTALL)
            if message_match:
                return {
                    "type": "message",
                    "content": message_match.group(1).strip()
                }

            # Extract tool output content
            tool_match = re.search(r'<tool_output.*?>(.*?)</tool_output>', content, re.DOTALL)
            if tool_match and 'mermaid' in content.lower():
                return {
                    "type": "tool_output",
                    "tool": "mermaid",
                    "content": self._clean_mermaid_content(tool_match.group(1))
                }
            return None
        except Exception as e:
            self.logger.error(f"XML parse error: {e}")
            self.logger.error(f"Content: {content}")
            return None

    def _remember_tool_output(self, tool_type: str, content: Any) -> None:
        """Append a tool output to the current message unless already stored.

        Two outputs count as duplicates when their ``"content"`` values are
        equal, regardless of type.
        """
        outputs = self.current_message.tool_outputs
        if content not in [entry.get("content") for entry in outputs]:
            outputs.append({
                "type": tool_type,
                "content": content
            })

    def _process_observation(self, data: Dict) -> Dict:
        """Record tool outputs found in ``data["observation"]``.

        When the observation string mentions ``mermaid_diagram`` it is
        cleaned, stored as a ``mermaid_diagram`` tool output and the
        observation is rewritten in place to a normalised JSON object.
        Otherwise, if the observation is a JSON object, each of its entries
        is stored as a tool output named after its key.

        Returns the same *data* dict (possibly modified).  Errors are logged
        and never propagated.
        """
        try:
            observation = data.get("observation")
            if observation and isinstance(observation, str):
                # Handle tool-specific content
                if "mermaid_diagram" in observation:
                    cleaned_content = self.clean_mermaid_content(observation)
                    self._remember_tool_output("mermaid_diagram", cleaned_content)
                    data["observation"] = json.dumps({
                        "mermaid_diagram": cleaned_content
                    })
                else:
                    # Handle other tool outputs; parse once instead of
                    # validating and then parsing again.
                    try:
                        tool_data = json.loads(observation)
                    except ValueError:
                        tool_data = None
                    if isinstance(tool_data, dict):
                        for tool_name, tool_output in tool_data.items():
                            self._remember_tool_output(tool_name, tool_output)
        except Exception as e:
            self.logger.error(f"Error processing observation: {e}")
        return data

    def _handle_message_end(self, data: Dict) -> None:
        """Handle the message end event and finalise the current message.

        Copies ``retriever_resources`` into the citations, attaches the
        collected tool outputs to the metadata and marks the message
        complete.  When *data* carries a metadata dict, that same dict object
        is stored and updated (the caller's dict gains a ``"tool_outputs"``
        entry); a missing or non-dict metadata value is replaced by a new
        empty dict.
        """
        metadata = data.get("metadata")
        if not isinstance(metadata, dict):
            metadata = {}
        metadata["tool_outputs"] = self.current_message.tool_outputs

        self.current_message.citations = data.get("retriever_resources") or []
        self.current_message.metadata = metadata
        self.current_message.is_complete = True

    def clean_mermaid_content(self, content: str) -> str:
        """Clean and format mermaid diagram content.

        Strips markdown fences, a ``tool response:`` prefix and a trailing
        ``.`` after a closing brace; if the remainder is a JSON object with a
        string ``mermaid_diagram`` entry, that entry is used.  Returns the
        stripped result.  On an unexpected error the error is logged and the
        content is returned in whatever state it had reached.
        """
        try:
            # Remove markdown and JSON formatting
            content = re.sub(r'```mermaid\s*|\s*```', '', content)
            content = re.sub(r'tool response:.*?{', '{', content)
            content = re.sub(r'}\s*\.$', '}', content)

            # Parse JSON if present
            if content.strip().startswith('{'):
                try:
                    content_dict = json.loads(content)
                except ValueError:
                    # Looks like JSON but is not; keep the text unchanged.
                    content_dict = None
                if isinstance(content_dict, dict):
                    diagram = content_dict.get("mermaid_diagram")
                    # Only accept textual diagrams so the result stays a str.
                    if isinstance(diagram, str):
                        content = diagram
            return content.strip()
        except Exception as e:
            self.logger.error(f"Error cleaning mermaid content: {e}")
            return content