import json
import os
import re
import traceback

import gradio as gr
import torch
import numpy as np
from groq import Groq
import spaces
from transformers import AutoModel, AutoTokenizer
from diffusers import StableDiffusion3Pipeline
from parler_tts import ParlerTTSForConditionalGeneration
import soundfile as sf
from langchain_groq import ChatGroq
from PIL import Image
from tavily import TavilyClient
from langchain.schema import AIMessage
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import RetrievalQA
from torchvision import transforms
import pandas

# Initialize models and clients.
# Everything below is created once, at import time, and shared by every
# request handled by the functions further down in this file.

# Groq chat model name used for text answers and document QA.
MODEL = 'llama-3.1-70b-versatile'
# Groq SDK client; reused for intent classification and Whisper transcription.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# MiniCPM-V-2 vision-language model and its tokenizer, used for image question
# answering. trust_remote_code lets transformers run the model code shipped in
# the Hub repository.
vqa_model = AutoModel.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True,
                                      device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True)

# Parler-TTS model and tokenizer for text-to-speech. No device is given here;
# main_interface() moves the model to CUDA before each request.
tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-large-v1")
tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-large-v1")

# Image generation: Stable Diffusion 3 Medium pipeline, loaded in float16 and
# placed on the GPU.
pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Tavily client for web search; its API key is read from the TAVILY_API
# environment variable.
tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API"))

# Synthesize speech for a text response and save it as a WAV file
def play_voice_output(response):
    """Synthesize speech for ``response`` with Parler-TTS and save it as a WAV file.

    Despite its name, this function does not play audio: it writes the
    generated waveform to ``output.wav`` and returns that path so the UI can
    serve it.

    Args:
        response: Text to speak. An object exposing a ``content`` attribute
            (such as the message returned by a LangChain chat model's
            ``invoke``) is also accepted; its ``content`` is spoken.

    Returns:
        str: Path of the written WAV file (``"output.wav"``).
    """
    print("Executing play_voice_output function")
    # Chat models return message objects rather than strings, and the
    # tokenizer only accepts text, so unwrap the message content first.
    text = getattr(response, "content", response)
    if not isinstance(text, str):
        text = str(text)
    description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."
    # Put the token ids on whatever device the model currently lives on, instead
    # of assuming CUDA, so inputs and weights always match.
    device = tts_model.device
    input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to(device)
    prompt_input_ids = tts_tokenizer(text, return_tensors="pt").input_ids.to(device)
    generation = tts_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()
    sf.write("output.wav", audio_arr, tts_model.config.sampling_rate)
    return "output.wav"

# Function to classify user input using LLM
# Labels the classifier is allowed to return; anything else falls back to
# _DEFAULT_FUNCTION.
_CLASSIFIER_LABELS = frozenset({"image_generation", "image_vqa", "document_qa", "text_to_text"})
_DEFAULT_FUNCTION = "text_to_text"


def _parse_function_label(raw):
    """Extract and validate the ``function`` label from a classifier reply.

    The prompt shows the model a fenced ```json example, so replies often come
    back wrapped in code fences or surrounded by prose. This pulls out the
    first ``{...}`` object, parses it as JSON and checks the label.

    Args:
        raw: The model's reply text (may be ``None``).

    Returns:
        str | None: A label from ``_CLASSIFIER_LABELS``, or ``None`` when the
        reply contains no usable label.
    """
    if not raw:
        return None
    match = re.search(r"\{.*?\}", raw, re.DOTALL)
    if match is None:
        return None
    try:
        parsed = json.loads(match.group(0))
    except json.JSONDecodeError:
        return None
    if not isinstance(parsed, dict):
        return None
    label = parsed.get("function")
    if not isinstance(label, str):
        return None
    label = label.strip()
    return label if label in _CLASSIFIER_LABELS else None


def classify_function(user_prompt):
    """Ask a small Groq model which pipeline should handle ``user_prompt``.

    Args:
        user_prompt: The user's request text.

    Returns:
        str: One of ``"image_generation"``, ``"image_vqa"``, ``"document_qa"``
        or ``"text_to_text"``. Falls back to ``"text_to_text"`` whenever the
        model's reply cannot be turned into one of those labels.
    """
    prompt = f"""
    You are a function classifier AI assistant. You are given a user input and you need to classify it into one of the following functions:

    - `image_generation`: If the user wants to generate an image.
    - `image_vqa`: If the user wants to ask questions about an image.
    - `document_qa`: If the user wants to ask questions about a document.
    - `text_to_text`: If the user wants a text-based response.

    Respond with a JSON object containing only the chosen function. For example:

    ```json
    {{"function": "image_generation"}}
    ```

    User input: {user_prompt}
    """

    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": prompt,
            }
        ],
        model="llama3-8b-8192",
    )

    raw = chat_completion.choices[0].message.content
    function = _parse_function_label(raw)
    if function is None:
        print(f"Error decoding JSON: {raw}")
        return _DEFAULT_FUNCTION  # Default to text-to-text if parsing fails
    return function

# Document Question Answering Tool
class DocumentQuestionAnswering:
    """Retrieval-augmented question answering over a single text document.

    Construction loads the document, splits it into chunks, embeds the chunks
    into an in-memory FAISS index and wires that index to a Groq chat model
    through a RetrievalQA "stuff" chain. ``run`` then answers questions
    against it.
    """

    def __init__(self, document):
        # `document` is handed straight to TextLoader, so it is expected to be
        # a path to a text file.
        self.document = document
        self.qa_chain = self._setup_qa_chain()

    def _setup_qa_chain(self):
        """Build and return the RetrievalQA chain for ``self.document``."""
        print("Setting up DocumentQuestionAnswering tool")
        raw_docs = TextLoader(self.document).load()
        # 1000-character chunks, no overlap between neighbours.
        splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        chunks = splitter.split_documents(raw_docs)
        # Default HuggingFace embedding model, indexed with FAISS.
        vector_store = FAISS.from_documents(chunks, HuggingFaceEmbeddings())
        doc_retriever = vector_store.as_retriever()
        groq_llm = ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))
        return RetrievalQA.from_chain_type(
            llm=groq_llm,
            chain_type="stuff",
            retriever=doc_retriever,
        )

    def run(self, query: str) -> str:
        """Answer ``query`` against the indexed document and return it as text."""
        print("Executing DocumentQuestionAnswering tool")
        return str(self.qa_chain.run(query))

# Function to handle different input types and choose the right pipeline
def _message_text(message):
    """Return the text of a chat-model reply (its ``content``), else ``str(message)``."""
    content = getattr(message, "content", message)
    return content if isinstance(content, str) else str(content)


def _ask_llm(llm, prompt):
    """Send ``prompt`` to the chat model and return the reply as plain text."""
    # Chat models take their input positionally; the former
    # `llm.invoke(query=...)` raised TypeError (missing `input`).
    return _message_text(llm.invoke(prompt))


def _transcribe_audio(audio):
    """Transcribe an audio upload with Groq Whisper and return the text.

    ``audio`` is normally a filepath string (the UI uses
    ``gr.Audio(type="filepath")``); an open binary file object with a
    ``name`` attribute is accepted as well.
    """
    if hasattr(audio, "read"):
        payload = (os.path.basename(audio.name), audio.read())
    else:
        with open(audio, "rb") as audio_file:
            payload = (os.path.basename(audio), audio_file.read())
    transcription = client.audio.transcriptions.create(
        file=payload,
        model="whisper-large-v3"
    )
    return transcription.text


def _answer_image_question(image_path, question):
    """Answer ``question`` about the image stored at ``image_path`` with the VQA model."""
    # The MiniCPM-V model card passes a PIL image to chat(), which does its own
    # preprocessing; the old manual Resize/ToTensor step fed it a tensor instead.
    with Image.open(image_path) as img:
        pil_image = img.convert('RGB')
    messages = [{"role": "user", "content": question}]
    result = vqa_model.chat(image=pil_image, msgs=messages, tokenizer=tokenizer,
                            context=None, temperature=0.5)
    # The MiniCPM-V-2 model card unpacks chat() into three values (answer first);
    # other revisions return only the answer. Accept both shapes.
    if isinstance(result, tuple):
        return result[0]
    return result


def handle_input(user_prompt, image=None, audio=None, websearch=False, document=None):
    """Route one request to the appropriate pipeline.

    Precedence: an audio upload (voice mode) wins, then web search, then the
    function chosen by ``classify_function``.

    Args:
        user_prompt: Text typed by the user (may be ``None`` or empty).
        image: Filepath of an uploaded image, if any.
        audio: Filepath (or binary file object) of an uploaded audio clip.
        websearch: When true, answer with a Tavily web search.
        document: Filepath of an uploaded text document, if any.

    Returns:
        tuple: ``(output, audio_path)``. ``output`` is a text answer or, for
        image generation, the path of the saved image. ``audio_path`` is the
        spoken reply's WAV file in voice mode, otherwise ``None``.
    """
    print(f"Handling input: {user_prompt}")

    # Initialize the LLM
    llm = ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))

    # Handle voice-only mode
    if audio:
        print("Processing audio input")
        user_prompt = _transcribe_audio(audio)
        response = _ask_llm(llm, user_prompt)
        audio_output = play_voice_output(response)
        return "Response generated.", audio_output

    # Handle websearch mode
    if websearch:
        print("Executing Web Search")
        answer = tavily_client.qna_search(query=user_prompt)
        return answer, None

    # Handle cases with only image or document input. An untouched textbox may
    # arrive as None, so normalise it to a string first.
    user_prompt = user_prompt or ""
    if not user_prompt.strip():
        if image:
            user_prompt = "Describe this image"
        elif document:
            user_prompt = "Summarize this document"
        else:
            return "Please enter a message.", None

    # Classify user input using LLM
    function = classify_function(user_prompt)

    # Handle different functions
    if function == "image_generation":
        print("Executing Image Generation")
        generated = pipe(
            user_prompt,
            negative_prompt="",
            num_inference_steps=15,
            guidance_scale=7.0,
        ).images[0]
        generated.save("output.jpg")
        return "output.jpg", None

    if function == "image_vqa":
        print("Executing Image Description")
        if not image:
            return "Please upload an image.", None
        return _answer_image_question(image, user_prompt), None

    if function == "document_qa":
        print("Executing Document Summarization")
        if not document:
            return "Please upload a document.", None
        document_qa = DocumentQuestionAnswering(document)
        return document_qa.run(user_prompt), None

    # function == "text_to_text" (also the fallback for anything unexpected)
    print("Executing Text-to-Text")
    return _ask_llm(llm, user_prompt), None

# Main interface function
@spaces.GPU(duration=120)
def main_interface(user_prompt, image=None, audio=None, voice_only=False, websearch=False, document=None):
    """Gradio click handler: put the models on the GPU and dispatch to handle_input.

    ``voice_only`` is accepted to match the UI inputs but is not used here;
    voice mode is triggered inside handle_input by the presence of ``audio``.

    Returns:
        tuple: ``(output, audio_path)``, one value per output component wired
        up in create_ui. On failure returns ``(error_message, None)``.
    """
    print("Starting main_interface function")
    # Ensure every model is on the GPU before any pipeline runs (this function
    # executes under the @spaces.GPU decorator).
    vqa_model.to(device='cuda', dtype=torch.bfloat16)
    tts_model.to("cuda")
    pipe.to("cuda")

    print(f"user_prompt: {user_prompt}, image: {image}, audio: {audio}, voice_only: {voice_only}, websearch: {websearch}, document: {document}")

    try:
        response = handle_input(user_prompt, image=image, audio=audio, websearch=websearch, document=document)
        print("handle_input function executed successfully")
    except Exception as e:
        # Top-level UI boundary: log the full traceback, then report a friendly
        # error. Two values are returned because the click event has two output
        # components; a bare string would make Gradio raise a mismatch error.
        print(f"Error in handle_input: {e}")
        traceback.print_exc()
        response = ("Error occurred during processing.", None)

    return response

def create_ui():
    """Build the Gradio Blocks interface and return it (without launching).

    Layout: a prompt textbox; image, audio and document uploads; voice-only and
    web-search toggles; a Submit button. Submit calls ``main_interface`` and
    shows its ``(output, audio_path)`` result in a Label and an Audio player.
    """
    with gr.Blocks(css="""
        /* Overall Styling */
        body {
            font-family: 'Poppins', sans-serif;
            background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
            margin: 0;
            padding: 0;
            color: #333;
        }

        /* Title Styling */
        .gradio-container h1 {
            text-align: center;
            padding: 20px 0;
            background: linear-gradient(45deg, #007bff, #00c6ff);
            color: white;
            font-size: 2.5em;
            font-weight: bold;
            letter-spacing: 1px;
            text-transform: uppercase;
            margin: 0;
            box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2);
        }

        /* Input Area Styling */
        .gradio-container .gr-row {
            display: flex;
            justify-content: space-around;
            align-items: center;
            padding: 20px;
            background-color: white;
            border-radius: 10px;
            box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.1);
            margin-bottom: 20px;
        }

        .gradio-container .gr-column {
            flex: 1;
            margin: 0 10px;
        }

        /* Textbox Styling */
        .gradio-container textarea {
            width: calc(100% - 20px);
            padding: 15px;
            border: 2px solid #007bff;
            border-radius: 8px;
            font-size: 1.1em;
            transition: border-color 0.3s, box-shadow 0.3s;
        }

        .gradio-container textarea:focus {
            border-color: #00c6ff;
            box-shadow: 0px 0px 8px rgba(0, 198, 255, 0.5);
            outline: none;
        }

        /* Button Styling */
        .gradio-container button {
            background: linear-gradient(45deg, #007bff, #00c6ff);
            color: white;
            padding: 15px 25px;
            border: none;
            border-radius: 8px;
            cursor: pointer;
            font-size: 1.2em;
            font-weight: bold;
            transition: background 0.3s, transform 0.3s;
            box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1);
        }

        .gradio-container button:hover {
            background: linear-gradient(45deg, #0056b3, #009bff);
            transform: translateY(-3px);
        }

        .gradio-container button:active {
            transform: translateY(0);
        }

        /* Output Area Styling */
        .gradio-container .output-area {
            padding: 20px;
            text-align: center;
            background-color: #f7f9fc;
            border-radius: 10px;
            box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.1);
            margin-top: 20px;
        }

        /* Image Styling */
        .gradio-container img {
            max-width: 100%;
            height: auto;
            border-radius: 10px;
            box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1);
            transition: transform 0.3s, box-shadow 0.3s;
        }

        .gradio-container img:hover {
            transform: scale(1.05);
            box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.2);
        }

        /* Checkbox Styling */
        .gradio-container input[type="checkbox"] {
            width: 20px;
            height: 20px;
            cursor: pointer;
            accent-color: #007bff;
            transition: transform 0.3s;
        }

        .gradio-container input[type="checkbox"]:checked {
            transform: scale(1.2);
        }

        /* Audio and Document Upload Styling */
        .gradio-container .gr-file-upload input[type="file"] {
            width: 100%;
            padding: 10px;
            border: 2px solid #007bff;
            border-radius: 8px;
            cursor: pointer;
            background-color: white;
            transition: border-color 0.3s, background-color 0.3s;
        }

        .gradio-container .gr-file-upload input[type="file"]:hover {
            border-color: #00c6ff;
            background-color: #f0f8ff;
        }

        /* Advanced Tooltip Styling */
        .gradio-container .gr-tooltip {
            position: relative;
            display: inline-block;
            cursor: pointer;
        }

        .gradio-container .gr-tooltip .tooltiptext {
            visibility: hidden;
            width: 200px;
            background-color: black;
            color: #fff;
            text-align: center;
            border-radius: 6px;
            padding: 5px;
            position: absolute;
            z-index: 1;
            bottom: 125%;
            left: 50%;
            margin-left: -100px;
            opacity: 0;
            transition: opacity 0.3s;
        }

        .gradio-container .gr-tooltip:hover .tooltiptext {
            visibility: visible;
            opacity: 1;
        }

        /* Footer Styling */
        .gradio-container footer {
            text-align: center;
            padding: 10px;
            background: #007bff;
            color: white;
            font-size: 0.9em;
            border-radius: 0 0 10px 10px;
            box-shadow: 0px -2px 8px rgba(0, 0, 0, 0.1);
        }

    """) as demo:
        gr.Markdown("# AI Assistant")
        with gr.Row():
            with gr.Column(scale=2):
                user_prompt = gr.Textbox(placeholder="Type your message here...", lines=1)
            with gr.Column(scale=1):
                image_input = gr.Image(type="filepath", label="Upload an image", elem_id="image-icon")
                audio_input = gr.Audio(type="filepath", label="Upload audio", elem_id="mic-icon")
                document_input = gr.File(type="filepath", label="Upload a document", elem_id="document-icon")
                voice_only_mode = gr.Checkbox(label="Enable Voice Only Mode", elem_id="voice-only-mode")
                websearch_mode = gr.Checkbox(label="Enable Web Search", elem_id="websearch-mode")
            with gr.Column(scale=1):
                submit = gr.Button("Submit")

        output_label = gr.Label(label="Output")
        # Hidden until voice-only mode is enabled (see toggle_voice_only).
        audio_output = gr.Audio(label="Audio Output", visible=False)

        submit.click(
            fn=main_interface,
            inputs=[user_prompt, image_input, audio_input, voice_only_mode, websearch_mode, document_input],
            outputs=[output_label, audio_output]
        )

        def toggle_voice_only(enabled):
            """Return one visibility update per component in the ``outputs`` list below.

            Voice-only mode hides the text, image, web-search and document inputs
            and shows the audio input plus the audio output player. The Submit
            button is left visible because it is the only trigger for
            main_interface.
            """
            # Gradio needs one value per output component; a single shared
            # update would raise an output-count mismatch.
            text_updates = [gr.update(visible=not enabled) for _ in range(4)]
            voice_updates = [gr.update(visible=enabled) for _ in range(2)]
            return tuple(text_updates + voice_updates)

        voice_only_mode.change(
            toggle_voice_only,
            inputs=voice_only_mode,
            outputs=[user_prompt, image_input, websearch_mode, document_input, audio_input, audio_output]
        )

    return demo

# Launch the UI: build the Blocks app and start the Gradio server. This runs
# at import time; `demo` stays bound at module level.
demo = create_ui()
demo.launch()