#############################################################################################################################
# Filename   : app.py
# Description: A Streamlit application to showcase the importance of Responsible AI in LLMs.
# Author     : Georgios Ioannou
#
# TODO: Add code for Google Gemma 7b and 7b-it.
# TODO: Write code documentation.
# Copyright © 2024 by Georgios Ioannou
#############################################################################################################################
# Import libraries.

import os  # Load environment variable(s).
import requests  # Send HTTP requests to Hugging Face models for inference.
import streamlit as st  # Build the GUI of the application.
import streamlit.components.v1 as components

from dataclasses import dataclass
from dotenv import find_dotenv, load_dotenv  # Read local .env file.
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationChain
from langchain.llms import OpenAI
from policies import complex_policy, simple_policy
from transformers import pipeline  # Access to Hugging Face models.
from typing import Literal


#############################################################################################################################
# Load environment variable(s).

load_dotenv(find_dotenv())  # Read the local .env file.

# HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
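

#############################################################################################################################
# NOTE: A minimal, hypothetical sketch of how the Gemma/LlamaGuard TODOs above
# could call the Hugging Face Inference API with `requests`. It is not wired
# into the app yet; it assumes HUGGINGFACEHUB_API_TOKEN is set in the .env file
# and that the endpoint returns the usual text-generation JSON payload.


def query_huggingface_model(prompt: str, model_id: str) -> str:
    """Send a prompt to a Hugging Face model and return its generated text."""
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"
    headers = {"Authorization": f"Bearer {os.getenv('HUGGINGFACEHUB_API_TOKEN')}"}
    # The Inference API expects a POST request with the prompt under "inputs".
    response = requests.post(api_url, headers=headers, json={"inputs": prompt})
    response.raise_for_status()  # Surface HTTP errors early.
    # Text-generation endpoints typically return [{"generated_text": "..."}].
    return response.json()[0]["generated_text"]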


#############################################################################################################################


@dataclass
class Message:
    """Class for keeping track of a chat message."""

    origin: Literal["human", "ai"]
    message: str


#############################################################################################################################
# Remember (keep history of) the user's messages across each rerun of the application;
# Streamlit re-executes the script on every interaction, so persistent state lives in st.session_state.


def initialize_session_state():
    if "history" not in st.session_state:
        st.session_state.history = []
    if "token_count" not in st.session_state:
        st.session_state.token_count = 0
    if "policy" not in st.session_state:
        st.session_state.policy = "No Policy"
    if "is_gemma_model" not in st.session_state:
        st.session_state.is_gemma_model = False
    if "selected_ai_icon" not in st.session_state:
        st.session_state.selected_ai_icon = "ai1.png"
    if "selected_user_icon" not in st.session_state:
        st.session_state.selected_user_icon = "man.png"


#############################################################################################################################


def on_click_callback():
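    # get_openai_callback() tracks token usage for the OpenAI calls made inside
    # this block; cb.total_tokens feeds the running count shown in the UI.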
    with get_openai_callback() as cb:
        human_prompt = st.session_state.human_prompt

        # Policy + User Query Text.

        role = "User"  # This can be an Agent.
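
        # The pieces below (task, policy, conversation, output format) follow the
        # LlamaGuard-style moderation prompt template: the selected policy's
        # categories are injected between the task and the conversation.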

        task = f"""
        Task: Check if there is unsafe content in
        '{role}' messages in conversations
        according to our safety policy with the below categories.
        """

        output_format = f"""
        Provide your safety assessment for '{role}' in the above conversation:
        - First line must read 'safe' or 'unsafe'.
        - If unsafe, a second line must include a comma-separated list of violated categories.
        """

        query = human_prompt

        conversation = f"""
        <BEGIN CONVERSATION>
        User: {query}
        <END CONVERSATION>
        """

        if st.session_state.policy == "Simple Policy":
            prompt = f"""
            {task}
            {simple_policy}
            {conversation}
            {output_format}
            """
        elif st.session_state.policy == "Complex Policy":
            prompt = f"""
            {task}
            {complex_policy}
            {conversation}
            {output_format}
            """
        elif st.session_state.policy == "No Policy":
            prompt = human_prompt

        # Getting the LLM response for safety check 1.
        # "https://api-inference.huggingface.co/models/meta-llama/LlamaGuard-7b"
        if st.session_state.is_gemma_model:
            # TODO: Run safety check 1 through the Gemma models (not implemented yet).
            # Return early so llm_response_safety_check_1 is never referenced unbound.
            return
        llm_response_safety_check_1 = st.session_state.conversation.run(prompt)
        st.session_state.history.append(Message("human", human_prompt))
        st.session_state.token_count += cb.total_tokens

        # Checking if the response is safe. Safety Check 1. Checking what goes in (user input).
        if (
            "unsafe" in llm_response_safety_check_1.lower()
        ):  # If the response is unsafe, return the safety verdict.
            st.session_state.history.append(Message("ai", llm_response_safety_check_1))
            return
        else:  # If the response is safe, answer the question.
            if st.session_state.is_gemma_model:
                # TODO: Generate the answer with the Gemma models (not implemented yet).
                return
            else:
                conversation_chain = ConversationChain(
                    llm=OpenAI(
                        temperature=0.2,
                        openai_api_key=OPENAI_API_KEY,
                        model_name=st.session_state.model,
                    ),
                )
                llm_response = conversation_chain.run(human_prompt)
                # st.session_state.history.append(Message("ai", llm_response))
                st.session_state.token_count += cb.total_tokens

        # Policy + LLM Response.
        query = llm_response

        conversation = f"""
        <BEGIN CONVERSATION>
        User: {query}
        <END CONVERSATION>
        """

        if st.session_state.policy == "Simple Policy":
            prompt = f"""
            {task}
            {simple_policy}
            {conversation}
            {output_format}
            """
        elif st.session_state.policy == "Complex Policy":
            prompt = f"""
            {task}
            {complex_policy}
            {conversation}
            {output_format}
            """
        elif st.session_state.policy == "No Policy":
            prompt = llm_response

        # Getting the LLM response for safety check 2.
        # "https://api-inference.huggingface.co/models/meta-llama/LlamaGuard-7b"
        if st.session_state.is_gemma_model:
            # TODO: Run safety check 2 through the Gemma models (not implemented yet).
            return
        llm_response_safety_check_2 = st.session_state.conversation.run(prompt)
        st.session_state.token_count += cb.total_tokens

        # Checking if the response is safe. Safety Check 2. Checking what goes out (LLM output).
        if (
            "unsafe" in llm_response_safety_check_2.lower()
        ):  # If the response is unsafe, return a warning instead of the response.
            st.session_state.history.append(
                Message(
                    "ai",
                    "THIS FROM THE AUTHOR OF THE CODE: LLM WANTED TO RESPOND UNSAFELY!",
                )
            )
        else:
            st.session_state.history.append(Message("ai", llm_response))


#############################################################################################################################
# Function to apply local CSS.


def local_css(file_name):
    with open(file_name) as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)


#############################################################################################################################


# Main function to create the Streamlit web application.


def main():
    initialize_session_state()

    # Page title and favicon.
    st.set_page_config(page_title="Responsible AI", page_icon="⚖️")

    # Load CSS.
    local_css("./static/styles/styles.css")

    # Title.
    title = f"""<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">
                Responsible AI</h1>"""
    st.markdown(title, unsafe_allow_html=True)

    # Subtitle 1.
    title = f"""<h3 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">
                Showcase the importance of Responsible AI in LLMs</h3>"""
    st.markdown(title, unsafe_allow_html=True)

    # Subtitle 2.
    title = f"""<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: 0rem">
                CUNY Tech Prep Tutorial 6</h2>"""
    st.markdown(title, unsafe_allow_html=True)

    # Image.
    image = "./static/ctp.png"
    left_co, cent_co, last_co = st.columns(3)
    with cent_co:
        st.image(image=image)

    # Sidebar dropdown menu for Models.
    models = [
        "gpt-4-turbo",
        "gpt-4",
        "gpt-3.5-turbo",
        "gpt-3.5-turbo-instruct",
        "gemma-7b",
        "gemma-7b-it",
    ]
    selected_model = st.sidebar.selectbox("Select Model:", models)
    st.sidebar.write(f"Current Model: {selected_model}")

    if selected_model == "gpt-4-turbo":
        st.session_state.model = "gpt-4-turbo"
    elif selected_model == "gpt-4":
        st.session_state.model = "gpt-4"
    elif selected_model == "gpt-3.5-turbo":
        st.session_state.model = "gpt-3.5-turbo"
    elif selected_model == "gpt-3.5-turbo-instruct":
        st.session_state.model = "gpt-3.5-turbo-instruct"
    elif selected_model == "gemma-7b":
        st.session_state.model = "gemma-7b"
    elif selected_model == "gemma-7b-it":
        st.session_state.model = "gemma-7b-it"

    if "gpt" in st.session_state.model:
        st.session_state.conversation = ConversationChain(
            llm=OpenAI(
                temperature=0.2,
                openai_api_key=OPENAI_API_KEY,
                model_name=st.session_state.model,
            ),
        )
    elif "gemma" in st.session_state.model:
        # Load model from Hugging Face.
        st.session_state.is_gemma_model = True
        pass
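        # One possible wiring (hypothetical, untested): the Hugging Face model id
        # would presumably be f"google/{st.session_state.model}", e.g.
        # "google/gemma-7b-it", which the query_huggingface_model() sketch above
        # could consume, or a local transformers pipeline("text-generation", ...).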

    # Sidebar dropdown menu for Policies.
    policies = ["No Policy", "Complex Policy", "Simple Policy"]
    selected_policy = st.sidebar.selectbox("Select Policy:", policies)
    st.sidebar.write(f"Current Policy: {selected_policy}")

    if selected_policy == "No Policy":
        st.session_state.policy = "No Policy"
    elif selected_policy == "Complex Policy":
        st.session_state.policy = "Complex Policy"
    elif selected_policy == "Simple Policy":
        st.session_state.policy = "Simple Policy"

    # Sidebar dropdown menu for AI Icons.
    ai_icons = ["AI 1", "AI 2"]
    selected_ai_icon = st.sidebar.selectbox("AI Icon:", ai_icons)
    st.sidebar.write(f"Current AI Icon: {selected_ai_icon}")

    if selected_ai_icon == "AI 1":
        st.session_state.selected_ai_icon = "ai1.png"
    elif selected_ai_icon == "AI 2":
        st.session_state.selected_ai_icon = "ai2.png"

    # Sidebar dropdown menu for User Icons.
    user_icons = ["Man", "Woman"]
    selected_user_icon = st.sidebar.selectbox("User Icon:", user_icons)
    st.sidebar.write(f"Current User Icon: {selected_user_icon}")

    if selected_user_icon == "Man":
        st.session_state.selected_user_icon = "man.png"
    elif selected_user_icon == "Woman":
        st.session_state.selected_user_icon = "woman.png"

    # Placeholder for the chat messages.
    chat_placeholder = st.container()
    # Placeholder for the user input.
    prompt_placeholder = st.form("chat-form")
    token_placeholder = st.empty()

    with chat_placeholder:
        for chat in st.session_state.history:
            div = f"""
    <div class="chat-row 
        {'' if chat.origin == 'ai' else 'row-reverse'}">
        <img class="chat-icon" src="app/static/{
            st.session_state.selected_ai_icon if chat.origin == 'ai' 
                        else st.session_state.selected_user_icon}"
            width=32 height=32>
        <div class="chat-bubble
        {'ai-bubble' if chat.origin == 'ai' else 'human-bubble'}">
            &#8203;{chat.message}
        </div>
    </div>
            """
            st.markdown(div, unsafe_allow_html=True)

        for _ in range(3):
            st.markdown("")

    # User prompt.
    with prompt_placeholder:
        st.markdown("**Chat**")
        cols = st.columns((6, 1))

        # Large text input in the left column.
        cols[0].text_input(
            "Chat",
            placeholder="What is your question?",
            label_visibility="collapsed",
            key="human_prompt",
        )
        # Red button in the right column.
        cols[1].form_submit_button(
            "Submit",
            type="primary",
            on_click=on_click_callback,
        )

    token_placeholder.caption(
        f"""
Used {st.session_state.token_count} tokens \n
"""
    )

    # GitHub repository of author.

    st.markdown(
        f"""
            <p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"><b> Check out our
            <a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;"> GitHub repository</a></b>
            </p>
    """,
        unsafe_allow_html=True,
    )

    # Use the Enter key on the keyboard to click the Submit button.
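    # (components.html renders the script in an iframe, so it reaches the app's
    # DOM through window.parent to find and click the Submit button.)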
    components.html(
        """
<script>
const streamlitDoc = window.parent.document;

const buttons = Array.from(
    streamlitDoc.querySelectorAll('.stButton > button')
);
const submitButton = buttons.find(
    el => el.innerText === 'Submit'
);

streamlitDoc.addEventListener('keydown', function(e) {
    switch (e.key) {
        case 'Enter':
            submitButton.click();
            break;
    }
});
</script>
""",
        height=0,
        width=0,
    )


#############################################################################################################################


if __name__ == "__main__":
    main()