import json
import re
import traceback
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Any, Optional, List, Union, Tuple

from logger_config import setup_logger

# Module-level logger created with setup_logger()'s default arguments.
# The classes visible in this file do not reference it; SSEParser builds its
# own logger via setup_logger("sse_parser").
logger = setup_logger()

@dataclass(eq=False)
class MessageState:
    """Mutable per-message accumulator used while a stream is being parsed.

    ``MessageState()`` still yields a blank state and every instance owns its
    own containers.  ``eq=False`` keeps identity-based equality and hashing,
    exactly as a plain class would; the dataclass only adds a readable repr
    and optional keyword construction.
    """

    # Text buffer; starts empty.
    buffer: str = ""
    # Set to True by SSEParser._handle_message_end.
    is_complete: bool = False
    # Entries of the form {"type": ..., "content": ...} appended by
    # SSEParser._process_observation.
    tool_outputs: list = field(default_factory=list)
    # Filled from the "retriever_resources" entry of the message-end event.
    citations: list = field(default_factory=list)
    # Filled from the "metadata" entry of the message-end event.
    metadata: dict = field(default_factory=dict)
    # Presumably identifiers of events already handled -- not used in the
    # code visible here; TODO confirm against callers.
    processed_events: set = field(default_factory=set)
    # Presumably the id of the message being assembled -- TODO confirm.
    current_message_id: Optional[Any] = None

class SSEParser:
    """Turn raw server-sent-event (SSE) text into dictionaries for the frontend.

    Three payload shapes are recognised after the ``data:`` marker: JSON
    objects, ``<agent_response>`` XML-like fragments and plain text.  Payloads
    carrying a mermaid diagram are normalised into a ``tool_output``
    dictionary.  Per-message state (tool outputs, citations, metadata) is
    accumulated on ``self.current_message``.
    """

    # Keys under which a tool payload may carry a mermaid diagram, in lookup order.
    _MERMAID_KEYS = ("mermaid_output", "mermaid_diagram")

    # Longest slice of a payload that is copied into a debug log line.
    _LOG_PREVIEW_CHARS = 200

    _MESSAGE_RE = re.compile(r'<message>(.*?)</message>', re.DOTALL)
    _TOOL_OUTPUT_RE = re.compile(r'<tool_output.*?>(.*?)</tool_output>', re.DOTALL)

    def __init__(self):
        self.logger = setup_logger("sse_parser")
        self.current_message = MessageState()

    # ------------------------------------------------------------------ helpers

    def _preview(self, value: Any) -> str:
        """Return a short text form of *value* that is safe to put in a log line."""
        # Slicing the raw value would raise for dicts and numbers, so convert first.
        text = value if isinstance(value, str) else repr(value)
        return text[:self._LOG_PREVIEW_CHARS]

    def _pick_mermaid_value(self, data: Any, fallback: Any) -> Any:
        """Return the value stored under a known mermaid key, else *fallback*."""
        # Only dicts are searched: ``key in some_str`` would be a substring test
        # and ``key in 42`` raises.
        if isinstance(data, dict):
            for key in self._MERMAID_KEYS:
                if key in data:
                    self.logger.debug(f"Found {key} format")
                    return data[key]
        return fallback

    def _is_recorded(self, content: Any) -> bool:
        """Tell whether *content* is already stored in the current tool outputs."""
        return content in (
            entry.get("content") for entry in self.current_message.tool_outputs
        )

    def _extract_json_content(self, data: str) -> Optional[str]:
        """Return the stripped text after the first ``data:`` marker, or None."""
        _, marker, payload = data.partition("data:")
        return payload.strip() if marker else None

    def _is_valid_json(self, content: str) -> bool:
        """Check if content is valid JSON (non-string input counts as invalid)."""
        try:
            json.loads(content)
        except (TypeError, ValueError):
            # json.JSONDecodeError is a ValueError; TypeError covers None etc.
            return False
        return True

    # ---------------------------------------------------------- mermaid cleanup

    def _clean_mermaid_content(self, content: str) -> Optional[str]:
        """Extract bare mermaid diagram source from a tool payload.

        Args:
            content: Usually a string: raw diagram text, or a JSON object that
                holds the diagram under ``mermaid_output`` / ``mermaid_diagram``
                (possibly JSON-encoded twice), optionally preceded by
                ``tool response:`` and followed by a single ``.``.  A dict is
                accepted as well; its ``tool_output`` entry, when present, is
                unwrapped first.

        Returns:
            The diagram text with the mermaid code fence removed, or None when
            the payload does not resolve to a string or an unexpected error
            occurs (the error is logged).
        """
        try:
            self.logger.debug(f"Starting mermaid content cleaning. Input type: {type(content)}")
            self.logger.debug(f"Initial content: {self._preview(content)}...")

            # Handle tool output format
            if isinstance(content, dict) and "tool_output" in content:
                self.logger.debug("Found tool_output in dict, extracting...")
                content = content["tool_output"]
                self.logger.debug(f"Extracted tool_output: {self._preview(content)}...")

            # Remove tool response prefix/suffix if present
            if isinstance(content, str):
                if "tool response:" in content:
                    self.logger.debug("Found 'tool response:' prefix, removing...")
                    content = content.split("tool response:", 1)[1].strip()
                if content.endswith('.'):
                    content = content[:-1]
                self.logger.debug(f"After prefix/suffix removal: {self._preview(content)}...")

            # Parse JSON if present; non-string input is inspected as-is.
            try:
                data = json.loads(content) if isinstance(content, str) else content
            except json.JSONDecodeError as e:
                self.logger.debug(f"Initial JSON parse failed: {str(e)}")
            else:
                self.logger.debug(f"Payload data type: {type(data).__name__}")
                content = self._pick_mermaid_value(data, content)

                # The diagram may itself be a JSON-encoded object: unwrap once more.
                if isinstance(content, str) and content.startswith('{'):
                    self.logger.debug("Content still appears to be JSON, attempting second parse...")
                    try:
                        content = self._pick_mermaid_value(json.loads(content), content)
                    except json.JSONDecodeError as e:
                        self.logger.debug(f"Second JSON parse failed: {str(e)}")

            if not isinstance(content, str):
                self.logger.warning("Content not in string format after processing")
                return None

            # Clean up markdown formatting
            self.logger.debug("Cleaning markdown formatting...")
            content = content.replace("```mermaid\n", "").replace("\n```", "").strip()

            # Remove any remaining JSON artifacts: fall back to the first value
            # of an object that used none of the known keys.
            if content.startswith('{'):
                self.logger.debug("Attempting to clean remaining JSON artifacts...")
                try:
                    leftover = json.loads(content)
                except json.JSONDecodeError as e:
                    self.logger.debug(f"Final JSON cleanup failed: {str(e)}")
                else:
                    if isinstance(leftover, dict) and leftover:
                        first_value = next(iter(leftover.values()))
                        # Keep the JSON text when the value is not a diagram
                        # string, so the result is always a str.
                        if isinstance(first_value, str):
                            content = first_value
                            self.logger.debug("Extracted value from remaining JSON")

            self.logger.debug(f"Final cleaned content: {self._preview(content)}...")
            return content

        except Exception as e:
            self.logger.error(f"Error cleaning mermaid content: {str(e)}")
            self.logger.error(f"Original content: {content}")
            self.logger.error(f"Stack trace: {traceback.format_exc()}")
            return None

    def clean_mermaid_content(self, content: str) -> str:
        """Clean and format mermaid diagram content.

        Code fences, a ``tool response:`` prefix and a trailing ``.`` after a
        closing brace are removed with regular expressions.  If the remainder
        is a JSON object, the string stored under ``mermaid_diagram`` (or,
        failing that, ``mermaid_output``) is returned; otherwise the cleaned
        text itself.

        Returns:
            The stripped diagram text.  When an unexpected error occurs (for
            example a non-string input) the error is logged and the content is
            returned as it was at that point.
        """
        try:
            # Remove markdown and JSON formatting
            content = re.sub(r'```mermaid\s*|\s*```', '', content)
            content = re.sub(r'tool response:.*?{', '{', content)
            content = re.sub(r'}\s*\.$', '}', content)

            # Parse JSON if present
            if content.strip().startswith('{'):
                try:
                    payload = json.loads(content)
                except json.JSONDecodeError as e:
                    # Best effort: text that merely looks like JSON is kept as-is.
                    self.logger.debug(f"Mermaid content is not valid JSON: {e}")
                else:
                    if isinstance(payload, dict):
                        for key in ("mermaid_diagram", "mermaid_output"):
                            diagram = payload.get(key)
                            # Ignore null / non-text values so a str is returned.
                            if isinstance(diagram, str):
                                content = diagram
                                break

            return content.strip()

        except Exception as e:
            self.logger.error(f"Error cleaning mermaid content: {e}")
            return content

    # ------------------------------------------------------------ event parsing

    def parse_sse_event(self, data: str) -> Optional[Dict]:
        """Parse SSE event data and format for frontend consumption.

        Args:
            data: Raw event text; everything after the first ``data:`` marker
                is treated as the payload.

        Returns:
            * None if there is no ``data:`` marker, the payload is empty, an
              ``<agent_response>`` fragment has no recognised content, or an
              unexpected error occurs (logged).
            * ``{"type": "tool_output", "tool": "mermaid", "content": ...}``
              for payloads carrying a mermaid diagram.
            * The parsed JSON object for any other JSON object payload.
            * ``{"type": "message", "content": <payload text>}`` for payloads
              that are not JSON objects (plain text, numbers, quoted strings).
        """
        try:
            self.logger.debug(f"Parsing SSE event. Raw data length: {len(data)}")

            # Clean up the data format - remove extra data: prefixes
            data = data.replace('data: data:', 'data:').replace('\r\n', '\n')

            # Extract JSON content from SSE data
            json_content = self._extract_json_content(data)
            if not json_content:
                self.logger.debug("No JSON content found in SSE data")
                return None

            # Handle text-wrapped JSON
            if json_content.startswith('{"text":'):
                try:
                    wrapper = json.loads(json_content)
                except json.JSONDecodeError as e:
                    self.logger.debug(f"Text wrapper is not valid JSON, using payload as-is: {e}")
                else:
                    inner_text = wrapper.get("text", "")
                    # Unwrap only real text; other values stay inside the
                    # wrapper object, which is then handled like any JSON.
                    if isinstance(inner_text, str):
                        json_content = inner_text

            self.logger.debug(f"Cleaned JSON content: {self._preview(json_content)}...")

            # Parse XML content if present
            if '<agent_response>' in json_content:
                return self._parse_xml_response(json_content)

            raw_message = {
                "type": "message",
                "content": json_content
            }

            # Parse JSON content
            try:
                parsed_data = json.loads(json_content)
            except json.JSONDecodeError:
                self.logger.debug("Failed to parse as JSON, treating as raw content")
                return raw_message

            if not isinstance(parsed_data, dict):
                # e.g. a streamed text chunk "42" is valid JSON but still a message.
                self.logger.debug("Payload is JSON but not an object, treating as raw content")
                return raw_message

            self.logger.debug(f"Parsed data keys: {parsed_data.keys()}")

            # Handle tool outputs
            if any(key in parsed_data for key in self._MERMAID_KEYS):
                return {
                    "type": "tool_output",
                    "tool": "mermaid",
                    "content": self._clean_mermaid_content(json_content)
                }

            return parsed_data

        except Exception as e:
            self.logger.error(f"Parse error: {str(e)}")
            self.logger.error(f"Raw data: {data}")
            self.logger.error(f"Stack trace: {traceback.format_exc()}")
            return None

    def _parse_xml_response(self, content: str) -> Optional[Dict]:
        """Parse XML response format.

        A ``<message>`` element takes precedence and yields a ``message``
        dictionary.  Otherwise a ``<tool_output>`` element yields a mermaid
        ``tool_output`` dictionary, but only if the text contains "mermaid"
        (case-insensitive).  Anything else, or an error, gives None.
        """
        try:
            # Extract message content
            message_match = self._MESSAGE_RE.search(content)
            if message_match:
                return {
                    "type": "message",
                    "content": message_match.group(1).strip()
                }

            # Extract tool output content
            tool_match = self._TOOL_OUTPUT_RE.search(content)
            if tool_match and 'mermaid' in content.lower():
                return {
                    "type": "tool_output",
                    "tool": "mermaid",
                    "content": self._clean_mermaid_content(tool_match.group(1))
                }

            return None

        except Exception as e:
            self.logger.error(f"XML parse error: {str(e)}")
            self.logger.error(f"Content: {content}")
            return None

    # ------------------------------------------------------------ message state

    def _process_observation(self, data: Dict) -> Dict:
        """Process observation content with special handling for tool outputs.

        Only a non-empty string under ``data["observation"]`` is examined.  If
        it mentions ``mermaid_diagram``, the cleaned diagram is recorded and,
        when it was not recorded before, the observation is rewritten to a
        canonical ``{"mermaid_diagram": ...}`` JSON string.  Otherwise, if the
        observation is a JSON object, each entry is recorded as a tool output
        unless an equal content value is already stored.

        Returns:
            The same ``data`` dictionary (possibly modified in place).  Errors
            are logged, never raised.
        """
        try:
            observation = data.get("observation")
            if not observation or not isinstance(observation, str):
                return data

            # Handle tool-specific content
            if "mermaid_diagram" in observation:
                diagram = self.clean_mermaid_content(observation)
                if not self._is_recorded(diagram):
                    self.current_message.tool_outputs.append({
                        "type": "mermaid_diagram",
                        "content": diagram
                    })
                    data["observation"] = json.dumps({
                        "mermaid_diagram": diagram
                    })
                return data

            # Handle other tool outputs; a single parse doubles as the validity check.
            try:
                tool_data = json.loads(observation)
            except json.JSONDecodeError:
                return data

            if isinstance(tool_data, dict):
                for tool_name, tool_output in tool_data.items():
                    if not self._is_recorded(tool_output):
                        self.current_message.tool_outputs.append({
                            "type": tool_name,
                            "content": tool_output
                        })
        except Exception as e:
            self.logger.error(f"Error processing observation: {str(e)}")
        return data

    def _handle_message_end(self, data: Dict) -> None:
        """Handle message end event and finalise the current message state.

        Citations come from ``data["retriever_resources"]`` and metadata from
        ``data["metadata"]``; a missing or null entry is replaced by an empty
        list / dict.  A metadata dict supplied by the event is stored as the
        same object and gains a ``tool_outputs`` entry.
        """
        citations = data.get("retriever_resources")
        metadata = data.get("metadata")
        if metadata is None:
            # A present-but-null "metadata" must not break the assignment below.
            metadata = {}
        metadata["tool_outputs"] = self.current_message.tool_outputs

        self.current_message.citations = [] if citations is None else citations
        self.current_message.metadata = metadata
        self.current_message.is_complete = True