from typing import Any, List, Tuple
import gradio as gr
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_openai import ChatOpenAI
from langchain_community.document_loaders import PyMuPDFLoader
import fitz
from PIL import Image
import os
import re
import openai

# MyApp class to handle the processes
class MyApp:
    """Holds the state of the PDF question-answering app.

    Attributes are filled in step by step: ``set_api_key`` stores the key,
    ``process_file`` loads the PDF, and ``build_chain`` indexes the loaded
    documents and creates the conversational retrieval chain.
    """

    def __init__(self) -> None:
        # OpenAI key used for embeddings and the chat model; None until set.
        self.OPENAI_API_KEY: "str | None" = None
        # ConversationalRetrievalChain created by ``build_chain``; None until built.
        self.chain = None
        # (question, answer) pairs passed to the chain as ``chat_history``.
        self.chat_history: "list[tuple[str, str]]" = []
        # Documents loaded from the PDF by ``process_file``; None until loaded.
        self.documents = None
        # Base name of the loaded PDF, also used as the Chroma collection name.
        self.file_name = None

    def set_api_key(self, api_key: str):
        """Store *api_key* on this instance and on the ``openai`` module."""
        self.OPENAI_API_KEY = api_key
        openai.api_key = api_key

    def process_file(self, file) -> Image.Image:
        """Load the PDF at ``file.name`` and render its first page.

        Args:
            file: Any object with a ``name`` attribute holding the PDF path.

        Returns:
            The first page rendered at 150 DPI as an RGB PIL image.
        """
        loader = PyMuPDFLoader(file.name)
        self.documents = loader.load()
        self.file_name = os.path.basename(file.name)
        # Close the document deterministically instead of relying on GC;
        # the image is built inside the block while the document is open.
        with fitz.open(file.name) as doc:
            pix = doc[0].get_pixmap(dpi=150)
            image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
        return image

    def build_chain(self, file) -> str:
        """Index the loaded documents in Chroma and build the QA chain.

        Args:
            file: Unused; kept so existing callers keep working.

        Returns:
            A success message.

        Raises:
            RuntimeError: If ``process_file`` has not loaded any documents yet.
        """
        if self.documents is None:
            # Fail with a clear message instead of an obscure Chroma error.
            raise RuntimeError(
                "No documents loaded; call process_file() before build_chain()."
            )
        embeddings = OpenAIEmbeddings(openai_api_key=self.OPENAI_API_KEY)
        pdfsearch = Chroma.from_documents(
            self.documents,
            embeddings,
            collection_name=self.file_name,
        )
        self.chain = ConversationalRetrievalChain.from_llm(
            ChatOpenAI(temperature=0.0, openai_api_key=self.OPENAI_API_KEY),
            # Only the single most similar chunk is retrieved per question.
            retriever=pdfsearch.as_retriever(search_kwargs={"k": 1}),
            return_source_documents=True,
        )
        return "Vector database built successfully!"

# Function to add text to chat history
def add_text(history: List[Tuple[str, str]], text: str) -> List[Tuple[str, str]]:
    """Queue the user's message in the chatbot transcript with an empty reply.

    The transcript is extended in place and also returned, so Gradio can
    re-render it right away; the reply slot is filled in later.

    Raises:
        gr.Error: If *text* is empty.
    """
    if text:
        pending_turn = (text, "")
        history.append(pending_turn)
        return history
    raise gr.Error("Enter text")

# Function to get response from the model
def get_response(history, query):
    """Ask the retrieval chain about *query* and record the exchange.

    Args:
        history: Chatbot transcript as a list of (user, bot) pairs. Its last
            entry is expected to be the pending ``(query, "")`` pair added by
            ``add_text``; it is updated in place with the reply.
        query: The user's question.

    Returns:
        A ``(history, source_texts)`` tuple: the updated transcript and the
        retrieved source passages (one paragraph each), or an error notice.

    Raises:
        gr.Error: If ``app.chain`` has not been built yet.
    """
    if app.chain is None:
        raise gr.Error("The chain has not been built yet. Please ensure the vector database is built before querying.")

    fallback = "I have no information about it. Feed me knowledge, please!"

    def _show_reply(reply):
        # Fill the pending bot slot; if the transcript is empty, append the
        # exchange instead of failing on history[-1].
        if history:
            history[-1] = (history[-1][0], reply)
        else:
            history.append((query, reply))

    # Only the chain call and result unpacking are guarded. Previously a
    # failure while formatting sources recorded a second, contradictory entry
    # for the same query in app.chat_history.
    try:
        result = app.chain.invoke(
            {"question": query, "chat_history": app.chat_history}
        )
        answer = result["answer"]
        source_docs = result["source_documents"]
    except Exception as e:
        app.chat_history.append((query, fallback))
        # Show the fallback in the chat instead of leaving an empty reply.
        _show_reply(fallback)
        return history, f"{fallback} Error: {e}"

    app.chat_history.append((query, answer))
    _show_reply(answer)

    source_texts = []
    for doc in source_docs:
        page = doc.metadata.get("page")
        # Stored page numbers are 0-based; tolerate a missing/non-int value.
        where = f"Page {page + 1}" if isinstance(page, int) else "Page ?"
        source_texts.append(f"{where}: {doc.page_content}")
    return history, "\n\n".join(source_texts)

# Function to get response for the current RAG tab
def get_response_current(history, query):
    """Answer *query* for the "current RAG" tab.

    This tab uses the same chain and conversation memory (``app``) as
    ``get_response``. Its body was a verbatim copy of that function, so it
    now delegates instead of duplicating the logic.

    Args:
        history: Chatbot transcript as a list of (user, bot) pairs.
        query: The user's question.

    Returns:
        Whatever ``get_response`` returns: ``(history, source_texts)``.
    """
    return get_response(history, query)

# Function to render file
def render_file(file) -> Image.Image:
    """Render the first page of the PDF at ``file.name``.

    Args:
        file: Any object with a ``name`` attribute holding the PDF path.

    Returns:
        The first page at 150 DPI as an RGB PIL image.
    """
    # Close the document deterministically instead of relying on GC; the image
    # is built while the document is still open.
    with fitz.open(file.name) as doc:
        pix = doc[0].get_pixmap(dpi=150)
        return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

# Function to purge chat and render first page of PDF
def purge_chat_and_render_first(file) -> Image.Image:
    """Clear the conversation memory and render page 1 of ``file``'s PDF.

    Args:
        file: Any object with a ``name`` attribute holding the PDF path.

    Returns:
        The first page at 150 DPI as an RGB PIL image.
    """
    app.chat_history = []
    # Close the document deterministically instead of relying on GC; the image
    # is built while the document is still open.
    with fitz.open(file.name) as doc:
        pix = doc[0].get_pixmap(dpi=150)
        return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

# Function to refresh chat
def refresh_chat():
    """Reset the chain's conversation memory and blank the chatbot display.

    Returns:
        A fresh, empty transcript for the Chatbot component.
    """
    app.chat_history = list()
    blank_transcript = []
    return blank_transcript

# Module-level application state read by the Gradio callbacks in this file:
# the API key, the built chain and the LLM conversation memory all live here.
# NOTE(review): being a single global, it is shared by every browser session
# served by this process — concurrent users share one chat history.
app = MyApp()

# Function to set API key
def set_api_key(api_key, pdf_path="THEDIA1.pdf"):
    """Store the OpenAI key and build the retrieval chain over the bundled PDF.

    Args:
        api_key: OpenAI API key entered in the UI. Surrounding whitespace
            (common when pasting) is stripped.
        pdf_path: PDF to index. It defaults to the file shipped with the app,
            resolved relative to the current working directory.

    Returns:
        A status string containing a masked form of the key.

    Raises:
        gr.Error: If the key is empty, the PDF is missing, or building the
            vector database fails.
    """
    api_key = (api_key or "").strip()
    if not api_key:
        raise gr.Error("Please enter your OpenAI API Key.")
    app.set_api_key(api_key)
    try:
        with open(pdf_path, "rb") as saved_file:
            # process_file/build_chain only use the object's ``.name`` path.
            app.process_file(saved_file)
            app.build_chain(saved_file)
    except FileNotFoundError as err:
        raise gr.Error(f"PDF not found: {pdf_path}") from err
    except Exception as err:
        raise gr.Error(f"Failed to build the vector database: {err}") from err
    # For keys of 8 chars or fewer, first-4 + last-4 would reveal the whole key.
    masked = f"{api_key[:4]}...{api_key[-4:]}" if len(api_key) > 8 else "****"
    return f"API Key set to {masked} and vector database built successfully!"

# Gradio interface
title = "🧘‍♀️ Dialectical Behaviour Therapy"

with gr.Blocks(title=title) as demo:
    # The title was previously assigned but never displayed.
    gr.Markdown(f"# {title}")

    api_key_input = gr.Textbox(label="OpenAI API Key", type="password", placeholder="Enter your OpenAI API Key")
    api_key_btn = gr.Button("Set API Key")
    # Caption goes in ``label``; putting it in ``value`` made it look like a status.
    api_key_status = gr.Textbox(label="API Key status", interactive=False)

    api_key_btn.click(
        fn=set_api_key,
        inputs=[api_key_input],
        outputs=[api_key_status],
    )

    with gr.Tab("Take a Dialectical Behaviour Therapy with Me"):
        with gr.Column():
            chatbot_current = gr.Chatbot(elem_id="chatbot_current")
            txt_current = gr.Textbox(
                show_label=False,
                placeholder="Enter text and press submit",
                scale=2,
            )
            submit_btn_current = gr.Button("Submit", scale=1)
            refresh_btn_current = gr.Button("Refresh Chat", scale=1)
            source_texts_output_current = gr.Textbox(label="Source Texts", interactive=False)

            # Show the user's message immediately, then fetch the answer only
            # if add_text succeeded (it raises on empty input).
            submit_btn_current.click(
                fn=add_text,
                inputs=[chatbot_current, txt_current],
                outputs=[chatbot_current],
                queue=False,
            ).success(
                fn=get_response_current,
                inputs=[chatbot_current, txt_current],
                outputs=[chatbot_current, source_texts_output_current],
            )

            refresh_btn_current.click(
                fn=refresh_chat,
                inputs=[],
                outputs=[chatbot_current],
            ).then(
                # Clear the stale sources together with the transcript.
                fn=lambda: "",
                inputs=None,
                outputs=[source_texts_output_current],
            )

demo.queue()

# Start the server only when run as a script, so importing this module has no
# side effect beyond building ``demo``.
if __name__ == "__main__":
    demo.launch()