from enum import Enum
import logging
from typing import List
import os
import re
from typing import List
from dotenv import load_dotenv
from openai import OpenAI
import phoenix as px
import llama_index
from llama_index.core.llms import ChatMessage, MessageRole

# Load variables from a local .env file into the process environment
# (python-dotenv), so values such as OPENAI_API_KEY can be read via os.getenv.
load_dotenv()


class IndexBuilder:
    """Template for building (or re-using) a vector index.

    Constructing an instance immediately runs :meth:`build_index`. The base
    hooks only print progress and never assign ``self.index``, so it stays
    ``None`` unless a subclass overrides them.

    Args:
        vdb_collection_name: name of the vector-database collection, stored as
            ``self.vdb_collection_name``.
        embed_model: embedding model, stored as ``self.embed_model`` for the hooks.
        is_load_from_vector_store: when true, document loading is skipped and
            ``_setup_index`` does not require ``self.documents`` to be set.
    """

    def __init__(self, vdb_collection_name, embed_model, is_load_from_vector_store=False):
        self.vdb_collection_name = vdb_collection_name
        self.embed_model = embed_model
        self.is_load_from_vector_store = is_load_from_vector_store
        self.documents = None
        self.index = None
        self.build_index()

    # NOTE: the misspelt name is the override point subclasses implement;
    # renaming it would silently bypass their document loading.
    def _load_doucments(self):
        """Hook: populate ``self.documents``. The base version does nothing."""

    def _setup_service_context(self):
        """Hook: configure the service context. The base version only logs."""
        print("Using global service context...")

    def _setup_vector_store(self):
        """Hook: connect to the vector store. The base version only logs."""
        print("Setup vector store...")

    def _setup_index(self):
        """Hook: build ``self.index``.

        Raises:
            ValueError: if documents are required (not loading from a vector
                store) but ``self.documents`` is still ``None``.
        """
        needs_documents = not self.is_load_from_vector_store
        if needs_documents and self.documents is None:
            raise ValueError("No documents provided for index building.")
        print("Building Index")

    def build_index(self):
        """Run the pipeline: load documents (unless re-using a vector store),
        then set up the service context, vector store and index, in that order."""
        if not self.is_load_from_vector_store:
            self._load_doucments()
        self._setup_service_context()
        self._setup_vector_store()
        self._setup_index()


class Chatbot:
    """Base chatbot offering RAG streaming plus plain-ChatGPT helpers.

    In this base class the ``_setup_*`` hooks only validate the index and print
    progress; none of them assigns ``self.chat_engine``. ``stream_chat``
    therefore presumably relies on a subclass overriding ``_setup_chat_engine``.
    """

    SYSTEM_PROMPT = ""
    DENIED_ANSWER_PROMPT = ""
    CHAT_EXAMPLES = []

    # Conversations with more than this many prior turns get a fixed refusal
    # instead of a RAG answer.
    MAX_HISTORY_TURNS = 10
    # Only this many of the most recent turns are forwarded as chat history.
    CONTEXT_TURNS = 3
    # Matches a trailing "---" + "參考:" (references) footer so it can be
    # stripped before the text is re-sent to the model as history.
    _REFERENCE_FOOTER_RE = re.compile(r"\n \n\n---\n\n參考: \n.*$", flags=re.DOTALL)

    def __init__(self, model_name, index_builder: IndexBuilder, llm=None):
        """Create the chatbot and run its setup pipeline.

        Args:
            model_name: model name passed to the OpenAI API by ``_invoke_chatgpt``.
            index_builder: builder whose ``index`` attribute this chatbot reuses.
            llm: optional LLM object stored on ``self.llm`` (unused by this base class).

        Raises:
            ValueError: if ``index_builder.index`` is ``None``.
        """
        self.model_name = model_name
        self.index_builder = index_builder
        self.llm = llm

        self.documents = None
        self.index = None
        self.chat_engine = None
        self.service_context = None
        self.vector_store = None
        self.tools = None

        self._setup_logger()
        self._setup_chatbot()

    def _setup_logger(self):
        """Log to ``logs/chatbot.log`` (appending) and bind ``self.logger``.

        ``logging.basicConfig`` configures the root logger only if it has no
        handlers yet, so later instances reuse the first configuration.
        """
        logs_dir = 'logs'
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(logs_dir, exist_ok=True)

        logging.basicConfig(
            filename=os.path.join(logs_dir, 'chatbot.log'),
            filemode='a',
            format='%(asctime)s - %(levelname)s - %(message)s',
            level=logging.INFO
        )
        self.logger = logging.getLogger(__name__)

    def _setup_chatbot(self):
        """Run the setup hooks in dependency order (index first, chat engine last)."""
        # self._setup_observer()  # uncomment to enable Arize Phoenix tracing
        self._setup_index()
        self._setup_query_engine()
        self._setup_tools()
        self._setup_chat_engine()

    def _setup_observer(self):
        """Launch the Arize Phoenix app and route LlamaIndex traces to it."""
        try:
            # llama_index >= 0.10 (the ``llama_index.core`` layout this module
            # imports) exposes the handler setter from ``llama_index.core``.
            from llama_index.core import set_global_handler
        except ImportError:  # legacy (< 0.10) top-level package layout
            set_global_handler = llama_index.set_global_handler

        px.launch_app()
        set_global_handler("arize_phoenix")

    def _setup_index(self):
        """Reuse the index produced by ``self.index_builder``."""
        self.index = self.index_builder.index
        print("Inherited index builder")

    def _setup_query_engine(self):
        """Hook for building a query engine; the base version only validates.

        Raises:
            ValueError: if no index is available.
        """
        if self.index is None:
            raise ValueError("No index built")
        print("Setup query engine...")

    def _setup_tools(self):
        """Hook for building agent tools; the base version only logs."""
        print("Setup tools...")

    def _setup_chat_engine(self):
        """Hook for building ``self.chat_engine``; the base version only validates.

        Raises:
            ValueError: if no index is available.
        """
        if self.index is None:
            raise ValueError("No index built")
        print("Setup chat engine...")

    def stream_chat(self, message, history):
        """Stream the chat engine's answer to *message*.

        Each yielded value is the cumulative answer so far, not just the new
        token. When *history* has more than ``MAX_HISTORY_TURNS`` turns, a fixed
        bilingual refusal is yielded instead and the engine is not called.
        """
        self.logger.info(history)
        # Build the converted history once; it is both logged and sent.
        chat_history = self.convert_to_chat_messages(history)
        self.logger.info(chat_history)
        if len(history) > self.MAX_HISTORY_TURNS:
            yield "Thank you for using AweSumCare. I'm sorry I can't answer your question now, but I'm still learning. Please try to ask me something else.\n感謝使用安心三寶。現時未能回答你的問題,請稍後再試。"
            return
        response = self.chat_engine.stream_chat(message, chat_history=chat_history)
        # Stream tokens as they are generated.
        partial_message = ""
        for token in response.response_gen:
            partial_message += token
            yield partial_message

    def convert_to_chat_messages(self, history: List[List[str]]) -> List[ChatMessage]:
        """Convert UI-style turn lists into llama_index ``ChatMessage`` objects.

        The result starts with a SYSTEM message holding ``SYSTEM_PROMPT``,
        followed by the last ``CONTEXT_TURNS`` turns. Within a turn, even
        positions become USER messages and odd positions ASSISTANT messages.
        Any reference footer is removed and the text is stripped. ``None``
        entries (e.g. a reply not produced yet) are skipped.
        """
        chat_messages = [ChatMessage(role=MessageRole.SYSTEM, content=self.SYSTEM_PROMPT)]
        # max(...) rather than history[-n:], so CONTEXT_TURNS == 0 keeps nothing
        # instead of everything.
        start = max(len(history) - self.CONTEXT_TURNS, 0)
        for conversation in history[start:]:
            for index, message in enumerate(conversation):
                if message is None:
                    continue
                role = MessageRole.USER if index % 2 == 0 else MessageRole.ASSISTANT
                clean_message = self._REFERENCE_FOOTER_RE.sub("", message)
                chat_messages.append(ChatMessage(role=role, content=clean_message.strip()))
        return chat_messages

    def predict_with_rag(self, message, history):
        """Answer with retrieval: return the ``stream_chat`` generator."""
        return self.stream_chat(message, history)

    # Bare-bones ChatGPT helpers, shared by every chatbot instance.
    def _invoke_chatgpt(self, history, message, is_include_system_prompt=False, temperature=1.0):
        """Stream a plain OpenAI chat completion (no retrieval) for *message*.

        Args:
            history: iterable of ``(user, assistant)`` pairs from earlier turns.
                An assistant value of ``None`` is omitted from the request.
            message: the new user message.
            is_include_system_prompt: prepend ``SYSTEM_PROMPT`` as a system message.
            temperature: sampling temperature sent to the API.

        Yields:
            The cumulative assistant reply after each streamed chunk.
        """
        openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        history_openai_format = []
        if is_include_system_prompt:
            history_openai_format.append({"role": "system", "content": self.SYSTEM_PROMPT})
        for human, assistant in history:
            history_openai_format.append({"role": "user", "content": human})
            # A turn whose reply is still pending has no assistant content to send.
            if assistant is not None:
                history_openai_format.append({"role": "assistant", "content": assistant})
        history_openai_format.append({"role": "user", "content": message})

        stream = openai_client.chat.completions.create(
            model=self.model_name,
            messages=history_openai_format,
            temperature=temperature,
            stream=True,
        )

        partial_message = ""
        for part in stream:
            # Some streamed chunks (e.g. a final usage chunk) carry an empty
            # ``choices`` list; indexing [0] on them would raise IndexError.
            if not part.choices:
                continue
            partial_message += part.choices[0].delta.content or ""
            yield partial_message

    def predict_with_prompt_wrapper(self, message, history):
        """'With Prompt Wrapper' mode: add the system prompt, no retrieval (no Pinecone)."""
        yield from self._invoke_chatgpt(history, message, is_include_system_prompt=True)

    def predict_vanilla_chatgpt(self, message, history):
        """'Vanilla ChatGPT' mode: no system prompt and no retrieval."""
        yield from self._invoke_chatgpt(history, message)