import streamlit as st
import os
import time
import json
import re
from typing import List, Literal, TypedDict
from transformers import AutoTokenizer
from tools.tools import toolsInfo
from gradio_client import Client
import constants as C
import utils as U

from openai import OpenAI
import anthropic
from groq import Groq

from dotenv import load_dotenv
load_dotenv()

# Identifiers of the supported LLM back-ends; they are also the keys of
# MODEL_CONFIG below.
ModelType = Literal["GPT4", "CLAUDE", "LLAMA"]
# Shape of a single MODEL_CONFIG entry.
# NOTE(review): the "LLAMA" entry additionally carries a "tools_model" key
# that is not declared here — confirm whether it should be listed as optional.
ModelConfig = TypedDict("ModelConfig", {
    "client": OpenAI | Groq | anthropic.Anthropic,
    "model": str,
    "max_context": int,
    "tokenizer": AutoTokenizer
})

# Active back-end, read from the MODEL_TYPE environment variable; an unset or
# empty value falls back to "LLAMA".
# NOTE(review): the value is not validated, so an unknown MODEL_TYPE surfaces
# as a KeyError at the MODEL_CONFIG lookups further down.
modelType: ModelType = os.environ.get("MODEL_TYPE") or "LLAMA"

# Per-back-end settings: API client, model name, context budget (in tokens)
# and the tokenizer used to estimate request sizes.
# NOTE(review): every client and tokenizer in this literal is constructed when
# the script runs, whichever back-end is selected — confirm that missing API
# keys for the unused providers do not fail here.
MODEL_CONFIG: dict[ModelType, ModelConfig] = {
    "GPT4": {
        "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
        "model": "gpt-4o-mini",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o")
    },
    "CLAUDE": {
        "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
        "model": "claude-3-5-sonnet-20240620",
        "max_context": 128000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
    },
    "LLAMA": {
        "client": Groq(api_key=os.environ.get("GROQ_API_KEY")),
        "model": "llama-3.1-70b-versatile",
        # "model": "llama-3.2-90b-text-preview",
        "tools_model": "llama3-groq-70b-8192-tool-use-preview",
        "max_context": 12800,  # intentionally reduced to 1/10th
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/Meta-Llama-3.1-Tokenizer")
    }
}

# Settings of the selected back-end, unpacked into module-level names used by
# the rest of the script.
client = MODEL_CONFIG[modelType]["client"]
MODEL = MODEL_CONFIG[modelType]["model"]
# Dedicated tool-calling model when the entry defines one, else the main model.
TOOLS_MODEL = MODEL_CONFIG[modelType].get("tools_model") or MODEL
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]

# The Claude back-end uses a different request/response shape than the other
# two, so the code below branches on this flag.
isClaudeModel = modelType == "CLAUDE"


def __countTokens(text):
    """Return the number of tokens the active tokenizer produces for ``text``.

    ``text`` may be any object; it is converted with ``str()`` before being
    encoded. Special tokens are not added to the encoding.
    """
    encoded = tokenizer.encode(str(text), add_special_tokens=False)
    return len(encoded)


# Page-level Streamlit settings: browser tab title and icon.
# NOTE(review): this call sits ahead of every other st.* output in the
# script — presumably on purpose; confirm before moving it.
st.set_page_config(
    page_title="Dr Newtons PG research Ai",
    page_icon=C.AI_ICON,
)

# Emit a raw <link rel="manifest"> tag pointing at manifest.json;
# unsafe_allow_html is required for the HTML to be passed through.
st.markdown('<link rel="manifest" href="manifest.json">', unsafe_allow_html=True)


def __isInvalidResponse(response: str):
    """Tell whether ``response`` looks like a broken LLM answer.

    The checks run in order and stop at the first one that fires; the reason
    is logged through ``U.pprint``.

    Returns:
        ``True`` when a check fires, otherwise ``None`` (falsy).
    """
    # Each entry pairs a lazily evaluated condition with its log message, so
    # later conditions are only computed when the earlier ones did not fire.
    checks = (
        (
            lambda: (
                len(re.findall(r'\n((?!http)[a-z])', response)) > 3
                and "```" not in response
            ),
            "new line followed by small case char",
        ),
        (
            lambda: len(re.findall(r'\b(\w+)(\s+\1){2,}\b', response)) > 1,
            "lot of consecutive repeating words",
        ),
        (
            lambda: len(re.findall(r'\n\n', response)) > 20,
            "lots of paragraphs",
        ),
        (
            lambda: C.EXCEPTION_KEYWORD in response,
            "LLM API threw exception",
        ),
        (
            lambda: ('{\n    "questions"' in response) and (C.JSON_SEPARATOR not in response),
            "JSON response without json separator",
        ),
        (
            lambda: ('{\n    "action"' in response) and (C.JSON_SEPARATOR not in response),
            "JSON response without json separator",
        ),
        (
            lambda: response.startswith(C.JSON_SEPARATOR),
            "only options with no text",
        ),
    )

    for hasFired, reason in checks:
        if hasFired():
            U.pprint(reason)
            return True


def __matchingKeywordsCount(keywords: List[str], text: str):
    """Count the entries of ``keywords`` that occur as substrings of ``text``.

    Every list entry is tested on its own, so a keyword listed twice is
    counted twice.
    """
    hits = 0
    for word in keywords:
        if word in text:
            hits += 1
    return hits


def __getMessages():
    """Return the session's messages, trimmed in place to the context budget.

    The estimated request size is the token count of the system prompt plus
    the token count of the stringified message list plus a margin of 100.
    While that estimate exceeds ``MAX_CONTEXT`` the oldest message is removed
    from ``st.session_state.messages``.

    Two safeguards apply while trimming:

    * The last remaining message is never removed. Without this, a budget
      that cannot be met ended in ``IndexError`` from popping an empty list.
    * A tool result that ends up first in the list (because the assistant
      turn that requested it was just removed) is removed as well, since a
      tool result without its preceding tool call is not a valid history.

    Returns:
        The (possibly shortened) ``st.session_state.messages`` list.
    """
    messages = st.session_state.messages
    # The system prompt does not change while trimming, so tokenize it once.
    systemMsgSize = __countTokens(C.SYSTEM_MSG)

    def getContextSize():
        currContextSize = systemMsgSize + __countTokens(messages) + 100
        U.pprint(f"{currContextSize=}")
        return currContextSize

    def startsWithToolResult():
        # Never consider the last remaining message for removal.
        if len(messages) <= 1:
            return False
        first = messages[0]
        if not isinstance(first, dict):
            return False
        # Shape appended by __processToolCalls.
        if first.get("role") == "tool":
            return True
        # Shape appended by __processClaudeToolCalls.
        content = first.get("content")
        return isinstance(content, list) and any(
            isinstance(part, dict) and part.get("type") == "tool_result"
            for part in content
        )

    while getContextSize() > MAX_CONTEXT:
        if len(messages) <= 1:
            U.pprint("Context size exceeded, but only one message is left; keeping it")
            break
        U.pprint("Context size exceeded, removing first message")
        messages.pop(0)
        while startsWithToolResult():
            U.pprint("Removing tool result left without its tool call")
            messages.pop(0)

    return messages


def __logLlmRequest(messagesFormatted: list, model: str):
    """Log the token count of the outgoing message list and the target model."""
    # Spelled out instead of the f-string "=" shorthand; the emitted text is
    # the same ("contextSize=<n> | <model>").
    U.pprint(f"contextSize={__countTokens(messagesFormatted)!r} | {model}")
    # U.pprint(f"{messagesFormatted=}")


# Tool schemas offered to the model on every request; currently only the
# Google search tool.
tools = [toolsInfo["getGoogleSearchResults"]["schema"]]


def __showToolResponse(toolResponseDisplay: dict):
    """Render a tool's display payload as an icon next to a short text.

    Args:
        toolResponseDisplay: Mapping with an optional ``"text"`` and an
            optional ``"icon"``. ``C.TOOL_ICON`` is used when no icon is
            given.

    Text that contains no backtick is wrapped in backticks so it renders as
    inline code; text that already contains one is rendered as is. A missing
    or empty text renders the icon only (a missing text previously raised
    ``TypeError`` on the backtick check).
    """
    text = toolResponseDisplay.get("text")
    msg = "" if text is None else str(text)
    icon = toolResponseDisplay.get("icon")

    col1, col2 = st.columns([1, 20])
    with col1:
        st.image(
            icon or C.TOOL_ICON,
            width=30
        )
    with col2:
        if msg:
            st.markdown(msg if "`" in msg else f"`{msg}`")


def __addToolCallToMsgs(toolCall: dict):
    """Append the assistant's tool-call turn to the session messages.

    For the Claude back-end ``toolCall`` is stored unchanged. For the other
    back-ends it is read as an object exposing ``id``, ``type`` and
    ``function.name`` / ``function.arguments`` and is stored as an assistant
    message carrying a single-entry ``tool_calls`` list.
    """
    if isClaudeModel:
        st.session_state.messages.append(toolCall)
        return

    calledFunction = toolCall.function
    serializedCall = {
        "id": toolCall.id,
        "function": {
            "name": calledFunction.name,
            "arguments": calledFunction.arguments,
        },
        "type": toolCall.type,
    }
    st.session_state.messages.append(
        {"role": "assistant", "tool_calls": [serializedCall]}
    )


def __processToolCalls(toolCalls):
    """Execute each requested tool and record call and result in the history.

    For every tool call: the function registered under its name in
    ``toolsInfo`` is invoked with the JSON-decoded arguments, an optional
    display payload is rendered and remembered in
    ``st.session_state.toolResponseDisplay``, and then the assistant's
    tool-call message followed by a ``"tool"`` message holding the result are
    appended to ``st.session_state.messages``.
    """
    for call in toolCalls:
        name = call.function.name
        handler = toolsInfo[name]["func"]
        rawArgs = call.function.arguments
        # Same output as the "=" f-string shorthand on the original names.
        U.pprint(f"functionName={name!r} | functionArgsStr={rawArgs!r}")

        outcome = handler(**json.loads(rawArgs))
        functionResponse = outcome.get("response")
        display = outcome.get("display")
        U.pprint(f"{functionResponse=}")

        if display:
            __showToolResponse(display)
            st.session_state.toolResponseDisplay = display

        __addToolCallToMsgs(call)

        toolMessage = {
            "role": "tool",
            "tool_call_id": call.id,
            "name": name,
            "content": functionResponse,
        }
        st.session_state.messages.append(toolMessage)


def __processClaudeToolCalls(toolResponse):
    """Execute the tool calls of a Claude reply and record them in the history.

    Args:
        toolResponse: The content blocks of the Claude reply.

    Tool calls are picked by block type (``"tool_use"``) rather than by
    position, so a reply without a leading text block, or with several tool
    calls, is handled. Every tool call in the stored assistant turn gets a
    matching ``tool_result`` entry in the single user message that follows it.

    The session messages are only modified after all tools have run, so a
    failing tool leaves the history untouched.
    """
    toolUseBlocks = [
        block for block in toolResponse
        if getattr(block, "type", None) == "tool_use"
    ]
    if not toolUseBlocks:
        # Blocks without a usable type: keep the positional assumption
        # (text first, tool call second).
        toolUseBlocks = [toolResponse[1]]

    toolResults = []
    for toolCall in toolUseBlocks:
        functionName = toolCall.name
        functionToCall = toolsInfo[functionName]["func"]
        functionArgs = toolCall.input
        functionResult = functionToCall(**functionArgs)
        functionResponse = functionResult.get("response")
        responseDisplay = functionResult.get("display")
        U.pprint(f"{functionResponse=}")

        if responseDisplay:
            __showToolResponse(responseDisplay)
            st.session_state.toolResponseDisplay = responseDisplay

        toolResults.append({
            "type": "tool_result",
            "tool_use_id": toolCall.id,
            "content": functionResponse,
        })

    __addToolCallToMsgs({
        "role": "assistant",
        "content": toolResponse
    })
    st.session_state.messages.append({
        "role": "user",
        "content": toolResults,
    })


def __dedupeToolCalls(toolCalls: list):
    """Keep one tool call per function name.

    When the same function is requested more than once, the last request
    wins; it takes the position of the first request for that name. A log
    entry is written whenever something was dropped.
    """
    # A dict keeps first-insertion order while later values overwrite earlier
    # ones, which gives exactly the "last call per name" result.
    lastCallByName = {
        (call.name if isClaudeModel else call.function.name): call
        for call in toolCalls
    }
    dedupedToolCalls = [*lastCallByName.values()]

    if len(dedupedToolCalls) != len(toolCalls):
        U.pprint("Deduped tool calls!")
        U.pprint(f"{toolCalls=} -> {dedupedToolCalls=}")

    return dedupedToolCalls


def __getClaudeTools():
    """Convert the module's tool schemas into the layout sent to Claude.

    Each entry of ``tools`` contributes the ``name``, ``description`` and
    ``parameters`` of its ``"function"`` section; ``parameters`` is passed on
    under the key ``input_schema``.
    """
    functionSpecs = (entry["function"] for entry in tools)
    return [
        {
            "name": spec["name"],
            "description": spec["description"],
            "input_schema": spec["parameters"],
        }
        for spec in functionSpecs
    ]


def __removeFunctionCall(response: str):
    """Strip tool calls that the model wrote into its answer as plain text.

    Removes every ``<function=getGoogleSearchResults>{"query": "..."}``
    fragment together with its closing tag. The closing tag is accepted both
    as ``</function>`` and as ``<function>``: the pattern used to require the
    slash-less form only, so a regularly closed call was never removed.

    Returns:
        ``response`` with all such fragments deleted.
    """
    pattern = r'<function=getGoogleSearchResults>\{"query": ".*?"\}</?function>'
    return re.sub(pattern, '', response)


def predict(model: str = None, attempts=0):
    """Run one LLM turn, resolving tool calls, and return the reply text.

    The request is built from the system prompt and the context-trimmed
    session messages. When the model requests tools, they are executed, their
    results are appended to the session messages, and the model is queried
    again (with the default model).

    Args:
        model: Model name to query; ``MODEL`` is used when falsy.
        attempts: Number of retries already spent because the model wrote a
            tool call as plain text (``<function=...>``) instead of using the
            tool-call channel.

    Returns:
        The reply text, possibly empty. On any failure a string consisting of
        ``C.EXCEPTION_KEYWORD`` and the error is returned. The function never
        returns ``None``, because the caller iterates over the result.
    """
    model = model or MODEL

    messagesFormatted = []
    try:
        if isClaudeModel:
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted, model)

            responseMessage = client.messages.create(
                model=model,
                messages=messagesFormatted,
                system=C.SYSTEM_MSG,
                temperature=0.5,
                max_tokens=4000,
                tools=__getClaudeTools()
            )

            responseMessageContent = responseMessage.content
            # Pick blocks by type instead of by position: a reply may consist
            # of a tool call only, in which case block 0 has no text.
            responseContent = "".join(
                block.text
                for block in responseMessageContent
                if getattr(block, "type", None) == "text"
            )
            toolCalls = [
                block
                for block in responseMessageContent
                if getattr(block, "type", None) == "tool_use"
            ]
        else:
            messagesFormatted = [{"role": "system", "content": C.SYSTEM_MSG}]
            messagesFormatted.extend(__getMessages())
            __logLlmRequest(messagesFormatted, model)

            response = client.chat.completions.create(
                model=model,
                messages=messagesFormatted,
                temperature=0.6,
                max_tokens=4000,
                stream=False,
                tools=tools
            )
            responseMessage = response.choices[0].message
            responseContent = responseMessage.content

            if responseContent and '<function=' in responseContent:
                U.pprint(f"Wrong toolCall response: {responseContent}")
                if attempts < 3:
                    U.pprint(f"Retrying...{attempts + 1}/3")
                    time.sleep(0.2)
                    return predict(model, attempts + 1)

                responseContent = __removeFunctionCall(responseContent)
                # Only fall back to the tools model when it is not the model
                # that just failed; otherwise the retry counter would restart
                # on every fallback and the recursion would never end.
                if "<function=" in responseContent and model != TOOLS_MODEL:
                    U.pprint("Switching to TOOLS_MODEL")
                    return predict(TOOLS_MODEL)
            toolCalls = responseMessage.tool_calls

        # U.pprint(f"{responseMessage=}")
        # U.pprint(f"{responseContent=}")
        U.pprint(f"{toolCalls=}")

        if toolCalls:
            toolCalls = __dedupeToolCalls(toolCalls)
            U.pprint("Deduping done!")
            try:
                if isClaudeModel:
                    __processClaudeToolCalls(responseMessage.content)
                else:
                    __processToolCalls(toolCalls)
            except Exception as e:
                # Report the failure the same way as an API error instead of
                # falling through and returning None.
                U.pprint(e)
                return f"{C.EXCEPTION_KEYWORD} | {e}"
            return predict()

        # The content can be absent (None); hand back an empty string so the
        # caller can treat it as an empty reply.
        return responseContent or ""
    except Exception as e:
        U.pprint(f"LLM API Error: {e}")
        return f"{C.EXCEPTION_KEYWORD} | {e}"


def __generateImage(prompt: str):
    """Request an image for ``prompt`` from the FLUX.1-schnell Space.

    A new ``gradio_client`` client is created on every call and its
    ``/infer`` endpoint is invoked for a 1024x768 image with 4 inference
    steps and a randomized seed.

    Returns:
        Whatever the endpoint returns; the value is logged before being
        handed back.
    """
    inferenceArgs = {
        "prompt": prompt,
        "seed": 0,
        "randomize_seed": True,
        "width": 1024,
        "height": 768,
        "num_inference_steps": 4,
        "api_name": "/infer",
    }
    result = Client("black-forest-labs/FLUX.1-schnell").predict(**inferenceArgs)
    U.pprint(f"imageResult={result}")
    return result


def __resetButtonState():
    """Clear the option-button selection kept in the session state."""
    st.session_state["buttonValue"] = ""


# Seed the session state; each key is only initialised when it is not
# present yet.
if "ipAddress" not in st.session_state:
    st.session_state["ipAddress"] = st.context.headers.get("x-forwarded-for")

# chatHistory holds what is rendered on screen, messages what is sent to the
# LLM; both start out empty.
for _stateKey in ("chatHistory", "messages"):
    if _stateKey not in st.session_state:
        st.session_state[_stateKey] = []

if "buttonValue" not in st.session_state:
    __resetButtonState()

# Reset on every script run, unlike the keys above.
st.session_state["toolResponseDisplay"] = {}

U.pprint("\n")
U.pprint("\n")

U.applyCommonStyles()
st.title("Dr Newtons Ai Client")


# Replay the stored conversation: text first, then the tool notice (if any),
# then the image (if any).
for chat in st.session_state["chatHistory"]:
    role, content = chat["role"], chat["content"]
    imagePath = chat.get("image")
    toolResponseDisplay = chat.get("toolResponseDisplay")
    avatar = C.USER_ICON if role != "assistant" else C.AI_ICON

    with st.chat_message(role, avatar=avatar):
        st.markdown(content)

        if toolResponseDisplay:
            __showToolResponse(toolResponseDisplay)
        if imagePath:
            st.image(imagePath)

# U.pprint(f"{st.session_state.buttonValue=}")
# U.pprint(f"{st.session_state.selectedStory=}")
# U.pprint(f"{st.session_state.startMsg=}")

# One chat turn: triggered by typed input or by a previously clicked option
# button, whose label is kept in st.session_state["buttonValue"].
if prompt := (
    st.chat_input("Ask anything ...")
    or st.session_state["buttonValue"]
):
    # The button selection is consumed by this turn.
    __resetButtonState()

    with st.chat_message("user", avatar=C.USER_ICON):
        st.markdown(prompt)
    U.pprint(f"{prompt=}")
    st.session_state.chatHistory.append({"role": "user", "content": prompt })
    st.session_state.messages.append({"role": "user", "content": prompt})

    with st.chat_message("assistant", avatar=C.AI_ICON):
        responseContainer = st.empty()

        def __printAndGetResponse():
            """Fetch a reply and reveal it progressively in the placeholder.

            Returns the full reply, or None as soon as the text accumulated so
            far is judged invalid (the caller then retries).
            """
            response = ""
            responseContainer.image(C.TEXT_LOADER)
            # Guard against a missing reply: iterating None would raise, an
            # empty string simply leads to a retry below.
            responseGenerator = predict() or ""

            for chunk in responseGenerator:
                response += chunk
                if __isInvalidResponse(response):
                    U.pprint(f"InvalidResponse={response}")
                    return

                # Stop updating the visible text once the JSON part begins.
                if C.JSON_SEPARATOR not in response:
                    responseContainer.markdown(response)

            return response

        # NOTE(review): the retry loop has no upper bound; a persistent
        # failure keeps retrying every 0.7 s.
        response = __printAndGetResponse()
        while not response:
            U.pprint("Empty response. Retrying..")
            time.sleep(0.7)
            response = __printAndGetResponse()

        U.pprint(f"{response=}")

        def selectButton(optionLabel):
            st.session_state["buttonValue"] = optionLabel
            U.pprint(f"Selected: {optionLabel}")

        rawResponse = response
        # Split at the first separator only: unpacking an unbounded split
        # into two names raised ValueError when the separator occurred more
        # than once.
        responseParts = response.split(C.JSON_SEPARATOR, 1)

        jsonStr = None
        if len(responseParts) > 1:
            [response, jsonStr] = responseParts

        imagePath = None
        # imageContainer = st.empty()
        # try:
        #     (imagePrompt, loaderText) = __getImagePromptDetails(prompt, response)
        #     if imagePrompt:
        #         imgContainer = imageContainer.container()
        #         imgContainer.write(
        #             f"""
        #             <div class='blinking code'>
        #             {loaderText}
        #             </div>
        #             """,
        #             unsafe_allow_html=True
        #         )
        #         # imgContainer.markdown(f"`{loaderText}`")
        #         imgContainer.image(C.IMAGE_LOADER)
        #         (imagePath, seed) = __generateImage(imagePrompt)
        #         imageContainer.image(imagePath)
        # except Exception as e:
        #     U.pprint(e)
        #     imageContainer.empty()

        # The on-screen history stores the text part only; the LLM history
        # stores the raw reply including the JSON part.
        toolResponseDisplay = st.session_state.toolResponseDisplay
        st.session_state.chatHistory.append({
            "role": "assistant",
            "content": response,
            "image": imagePath,
            "toolResponseDisplay": toolResponseDisplay
        })
        st.session_state.messages.append({
            "role": "assistant",
            "content": rawResponse,
        })

        if jsonStr:
            try:
                jsonStr = jsonStr.replace("```", "")
                jsonObj = json.loads(jsonStr)
                questions = jsonObj.get("questions")
                action = jsonObj.get("action")

                if questions:
                    for option in questions:
                        st.button(
                            option["label"],
                            key=option["id"],
                            # Bind the label as a default argument so every
                            # button keeps its own label.
                            on_click=lambda label=option["label"]: selectButton(label)
                        )
                elif action:
                    pass
            except Exception as e:
                # A malformed JSON part only costs the option buttons; the
                # reply itself is already shown and stored.
                U.pprint(e)

# if st.button("Rerun"):
#     # __resetButtonState()
#     st.session_state.chatHistory = []
#     st.session_state.messages = []
#     st.rerun()