#############################################################################################################################
# Filename   : app.py
# Description: A Streamlit application that chains five models, feeding each model's output into the next.
#              Models used:
#                           1. Visual Question Answering (VQA).
#                           2. Fill-Mask.
#                           3. Text2text Generation.
#                           4. Text Generation.
#                           5. Topic Modeling (BERTopic).
# Author     : Georgios Ioannou
#
# Copyright © 2024 by Georgios Ioannou
#############################################################################################################################

# Import libraries.

import io  # Wrap uploaded image bytes in an in-memory file for PIL.

import streamlit as st  # Build the GUI of the application.
import torch  # Load Salesforce/blip model(s) on GPU.

from bertopic import BERTopic  # Topic model inference.
from PIL import Image  # Open and identify a given image file.
from transformers import (
    pipeline,
    BlipProcessor,
    BlipForQuestionAnswering,
)  # VQA model inference.

#############################################################################################################################

# Function to apply local CSS.


def local_css(file_name):
    """Inject the stylesheet stored at ``file_name`` into the page.

    The file's contents are wrapped in a ``<style>`` tag and emitted through
    ``st.markdown`` with ``unsafe_allow_html=True``.

    Args:
        file_name: Path to a CSS file, read as UTF-8 text.

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError).
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) can fail
    # or mis-decode non-ASCII characters in the stylesheet.
    with open(file_name, encoding="utf-8") as f:
        css = f.read()
    # Render after the file is closed; only the text is needed.
    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)


#############################################################################################################################

# Model 1.
# Model 1 gets input from the user.
# User -> Model 1

# Load the Visual Question Answering (VQA) model directly.
# Using transformers.


@st.cache_resource
def load_model_blip():
    """Load the BLIP visual-question-answering processor and model.

    Both come from the "Salesforce/blip-vqa-base" checkpoint. The function is
    wrapped in ``st.cache_resource`` so the download/initialisation happens
    once and the same objects are reused on later reruns.

    Returns:
        tuple: ``(BlipProcessor, BlipForQuestionAnswering)`` for the base checkpoint.
    """
    # Backup option: "Salesforce/blip-vqa-capfilt-large" can be loaded the same way.
    checkpoint = "Salesforce/blip-vqa-base"

    processor = BlipProcessor.from_pretrained(checkpoint)
    model = BlipForQuestionAnswering.from_pretrained(checkpoint)

    return processor, model


# General function for any Salesforce/blip model(s).
# VQA model.


def generate_answer_blip(processor, model, image, question, max_length=50):
    """Answer ``question`` about ``image`` with a Salesforce/blip VQA model.

    Args:
        processor: BLIP processor used to encode the image + question and to
            decode the generated token ids.
        model: BLIP question-answering model exposing ``generate`` and ``device``.
        image: Image handed to the processor (the app passes an RGB PIL image).
        question: Natural-language question about the image.
        max_length: Maximum generation length passed to ``model.generate``
            (defaults to 50, the previously hard-coded value).

    Returns:
        list[str]: Decoded answers, one per generated sequence.
    """
    # Encode image + question, then move every tensor to the model's device.
    # The processor builds tensors on the CPU. Without this, a model placed on
    # the GPU raises a device-mismatch error inside generate().
    inputs = processor(images=image, text=question, return_tensors="pt").to(
        model.device
    )

    generated_ids = model.generate(**inputs, max_length=max_length)

    return processor.batch_decode(generated_ids, skip_special_tokens=True)


# Generate answer from the Salesforce/blip model(s).
# VQA model.


@st.cache_data
def generate_answer(image, question):
    """Answer ``question`` about ``image`` using the Salesforce/blip-vqa-base model.

    Reads the module-level ``blip_processor_base`` / ``blip_model_base``, which
    are loaded at startup.

    The result is a plain list of strings, so it is cached with
    ``st.cache_data`` rather than ``st.cache_resource``. ``st.cache_resource``
    is meant for unserializable global resources and hands every caller the
    same shared object. ``st.cache_data`` returns a fresh copy per call.

    Args:
        image: RGB PIL image supplied by the user.
        question: Question about the image.

    Returns:
        list[str]: Decoded answers from ``generate_answer_blip``.
    """
    # Backup: a blip_processor_large / blip_model_large pair could be passed here instead.
    return generate_answer_blip(
        processor=blip_processor_base,
        model=blip_model_base,
        image=image,
        question=question,
    )


#############################################################################################################################

# Model 2.
# Model 2 gets input from Model 1.
# User -> Model 1 -> Model 2


@st.cache_resource
def load_model_fill_mask():
    """Create the Model 2 fill-mask pipeline backed by "bert-base-uncased".

    Cached with ``st.cache_resource`` so the pipeline is built only once.
    """
    fill_mask_pipeline = pipeline("fill-mask", model="bert-base-uncased")
    return fill_mask_pipeline


#############################################################################################################################

# Model 3.
# Model 3 gets input from Model 2.
# User -> Model 1 -> Model 2 -> Model 3


@st.cache_resource
def load_model_text2text_generation():
    """Create the Model 3 text2text-generation pipeline.

    Backed by "facebook/blenderbot-400M-distill" and cached with
    ``st.cache_resource`` so it is built only once.
    """
    checkpoint = "facebook/blenderbot-400M-distill"
    text2text_pipeline = pipeline("text2text-generation", model=checkpoint)
    return text2text_pipeline


#############################################################################################################################

# Model 4.
# Model 4 gets input from Model 3.
# User -> Model 1 -> Model 2 -> Model 3 -> Model 4


@st.cache_resource
def load_model_fill_text_generation():
    """Create the Model 4 text-generation pipeline backed by "gpt2".

    Cached with ``st.cache_resource`` so the pipeline is built only once.
    """
    generation_pipeline = pipeline("text-generation", model="gpt2")
    return generation_pipeline


#############################################################################################################################

# Model 5.
# Model 5 gets input from Model 4.
# User -> Model 1 -> Model 2 -> Model 3 -> Model 4 -> Model 5


@st.cache_resource
def load_model_bertopic1():
    """Load the first Model 5 BERTopic model ("davanstrien/chat_topics").

    Cached with ``st.cache_resource`` so the model is loaded only once.
    """
    source = "davanstrien/chat_topics"
    topic_model = BERTopic.load(source)
    return topic_model


@st.cache_resource
def load_model_bertopic2():
    """Load the second Model 5 BERTopic model ("MaartenGr/BERTopic_ArXiv").

    Cached with ``st.cache_resource`` so the model is loaded only once.
    """
    source = "MaartenGr/BERTopic_ArXiv"
    topic_model = BERTopic.load(source)
    return topic_model


#############################################################################################################################
# Configure the browser tab: page title and favicon.
st.set_page_config(page_title="Visual Question Answering", page_icon="❓")

#############################################################################################################################

# Run the Salesforce/blip VQA model on CUDA when PyTorch reports it is available, otherwise on the CPU.
# (Apple-silicon "mps" was considered but is intentionally not used.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch the cached BLIP processor/model pair and move the model's weights onto `device`.
blip_processor_base, blip_model_base = load_model_blip()
blip_model_base.to(device)

#############################################################################################################################
# Main function to create the Streamlit web application.
#
# 5 MODEL INFERENCES.
# User Input = Image + Question About The Image.
# User -> Model 1 -> Model 2 -> Model 3 -> Model 4 -> Model 5


def _show_header():
    """Render the page chrome: custom CSS, title, subtitle and centered logo."""
    # Load CSS.
    local_css("styles/style.css")

    # Title.
    title = """<h1 align="center" style="font-family: monospace; font-size: 2.1rem; margin-top: -4rem">
                    Georgios Ioannou's Visual Question Answering</h1>"""
    st.markdown(title, unsafe_allow_html=True)

    # Subtitle.
    subtitle = """<h2 align="center" style="font-family: monospace; font-size: 1.5rem; margin-top: -2rem">
                    CUNY Tech Prep Tutorial 4</h2>"""
    st.markdown(subtitle, unsafe_allow_html=True)

    # Logo, centered by placing it in the middle of three equal columns.
    _, center_column, _ = st.columns(3)
    with center_column:
        st.image(image="./ctp.png")


def _topic_representation(topic_model, document):
    """Return the "Representation" of the topic that ``topic_model`` assigns to ``document``."""
    # BERTopic.transform expects document strings. The text-generation pipeline
    # output is a list of dicts, so the caller must pass the extracted text.
    topics, _ = topic_model.transform([document])
    topic_info = topic_model.get_topic_info(topics[0])
    # Positional access: take the first matching row regardless of its index label.
    return topic_info["Representation"].iloc[0]


def _run_models(raw_image, question):
    """Chain VQA -> Fill-Mask -> Text2text -> Text Generation -> Topic, showing each output."""
    # Model 1: Visual Question Answering.
    with st.spinner(text="VQA inference..."):
        answer = generate_answer(raw_image, question)[0]
    st.success(f"VQA: {answer}")

    # Model 2: Fill-Mask, seeded with the VQA answer.
    bbu_pipeline = load_model_fill_mask()
    text = f"I love {answer} and I would like to know how to [MASK]."
    with st.spinner(text="Fill-Mask inference..."):
        bbu_pipeline_output = bbu_pipeline(text)
    bbu_output = bbu_pipeline_output[0]["sequence"]
    st.success(f"Fill-Mask: {bbu_output}")

    # Model 3: Text2text Generation on the completed sentence.
    facebook_pipeline = load_model_text2text_generation()
    with st.spinner(text="Text2text Generation inference..."):
        facebook_pipeline_output = facebook_pipeline(bbu_output)
    facebook_output = facebook_pipeline_output[0]["generated_text"]
    st.success(f"Text2text Generation: {facebook_output}")

    # Model 4: Text Generation continuing the chatbot reply.
    gpt2_pipeline = load_model_fill_text_generation()
    with st.spinner(text="Fill Text Generation inference..."):
        gpt2_pipeline_output = gpt2_pipeline(facebook_output)
    gpt2_output = gpt2_pipeline_output[0]["generated_text"]
    st.success(f"Fill Text Generation: {gpt2_output}")

    # Model 5: topic modeling of the generated text with two BERTopic models.
    with st.spinner(text="Topic inference..."):
        topic_model_1_output = _topic_representation(load_model_bertopic1(), gpt2_output)
    st.success(f"Topic(s) from davanstrien/chat_topics: {topic_model_1_output}")

    with st.spinner(text="Topic inference..."):
        topic_model_2_output = _topic_representation(load_model_bertopic2(), gpt2_output)
    st.success(f"Topic(s) from MaartenGr/BERTopic_ArXiv: {topic_model_2_output}")


def main():
    """Build the Streamlit page and run the five-model chain on the user's image + question.

    Any exception raised while rendering or running the models is shown on the
    page via ``st.error``. The footer with the GitHub link is always rendered.
    """
    try:
        _show_header()

        # User input (Image).
        uploaded_image = st.file_uploader(
            "Choose an image...", type=["jpg", "jpeg", "png"]
        )

        if uploaded_image is not None:
            st.image(uploaded_image, caption="Uploaded Image.", use_column_width=True)
            # Decode the upload in memory instead of saving it to disk under the
            # client-supplied file name, which could overwrite files next to the
            # app (e.g. ./ctp.png).
            raw_image = Image.open(io.BytesIO(uploaded_image.getvalue())).convert(
                "RGB"
            )

            # User input (Question).
            question = st.text_input("What's your question?")

            if question:
                _run_models(raw_image, question)
    except Exception as e:
        # Top-level boundary: surface any failure in the UI instead of crashing the page.
        st.error(e)

    # GitHub repository of author.

    st.markdown(
        f"""
            <p align="center" style="font-family: monospace; color: #FAF9F6; font-size: 1rem;"><b> Check out our
            <a href="https://github.com/GeorgiosIoannouCoder/" style="color: #FAF9F6;"> GitHub repository</a></b>
            </p>
    """,
        unsafe_allow_html=True,
    )


#############################################################################################################################
# Script entry point: build the page when this file is executed directly
# (e.g. via `streamlit run app.py`).
if __name__ == "__main__":
    main()