import os
import spaces

import gradio as gr
import torch
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import (
    process_images,
    process_queries,
)
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import re
import time
from PIL import Image
import torch
import subprocess
#subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)






@spaces.GPU
def model_inference(
    images, text,
):
    """Answer ``text`` with Qwen2-VL-2B-Instruct, using retrieved pages as context.

    The model and processor are loaded on every call and released before
    returning, so GPU memory is only held for the duration of the request.

    Args:
        images: Value of the results ``gr.Gallery``. Each entry is indexed with
            ``[0]`` to obtain an image path (presumably ``(path, caption)``
            tuples — depends on the Gradio version). ``None``/empty when no
            search has been run.
        text: The user's question.

    Returns:
        The decoded answer for the single prompt in the batch.

    Raises:
        gr.Error: If there are no retrieved pages to answer from.
    """
    # An empty gallery arrives as None; fail with a readable message instead
    # of a TypeError from iterating over None.
    if not images:
        raise gr.Error("No retrieved pages to answer from. Please run a search first.")

    content = []
    for entry in images:
        # Load the pixels eagerly and close the file handle right away.
        with Image.open(entry[0]) as page:
            content.append({"type": "image", "image": page.convert("RGB")})
    content.append({"type": "text", "text": text})

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        # attn_implementation="flash_attention_2" is recommended upstream for
        # multi-image inputs but does not work on ZeroGPU hardware.
        trust_remote_code=True,
        torch_dtype="auto",
    ).cuda().eval()
    # Default processor. Pass min_pixels/max_pixels here to bound the number
    # of visual tokens per image if memory becomes a problem.
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

    try:
        messages = [{"role": "user", "content": content}]

        prompt = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = processor(
            text=[prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to("cuda")

        generated_ids = model.generate(**inputs, max_new_tokens=512)
        # generate() returns prompt + completion; keep only the new tokens.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
    finally:
        # Release the model even when generation fails, so the GPU is freed.
        del model, processor
        torch.cuda.empty_cache()
    return output_text[0]



@spaces.GPU
def search(query: str, ds, images, k):
    """Return the ``k`` indexed pages that best match ``query``.

    Loads ColPali (adapter ``vidore/colpali-v1.2`` on top of
    ``vidore/colpaligemma-3b-pt-448-base``), embeds the query, and ranks the
    page embeddings in ``ds`` with ``CustomEvaluator(is_multi_vector=True)``.

    Args:
        query: Free-text search query.
        ds: Page embeddings produced at indexing time, aligned index-for-index
            with ``images``.
        images: Page images in the same order as ``ds``.
        k: Number of results to return (value of the UI slider).

    Returns:
        List of up to ``k`` page images, best match first.

    Raises:
        gr.Error: If nothing is indexed, or ``ds`` and ``images`` disagree in length.
    """
    if not ds:
        raise gr.Error("No documents indexed yet. Please index documents first.")
    if len(ds) != len(images):
        raise gr.Error("The index is out of sync with the page images. Please re-index the documents.")
    # Slicing needs an int; slider values may arrive as floats.
    top_k = int(k)

    # Load colpali model
    model_name = "vidore/colpali-v1.2"
    token = os.environ.get("HF_TOKEN")
    model = ColPali.from_pretrained(
        "vidore/colpaligemma-3b-pt-448-base",
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        token=token,
    )
    model.load_adapter(model_name)
    model = model.eval()
    processor = AutoProcessor.from_pretrained(model_name, token=token)

    # Compare torch.device objects: a plain str never equals a torch.device.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if model.device != device:
        model.to(device)

    # process_queries takes a placeholder image alongside the query text.
    mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
    try:
        with torch.no_grad():
            batch_query = process_queries(processor, [query], mock_image)
            batch_query = {key: tensor.to(device) for key, tensor in batch_query.items()}
            embeddings_query = model(**batch_query)
        qs = list(torch.unbind(embeddings_query.to("cpu")))
    finally:
        # Free GPU memory as soon as the query is embedded (also on error).
        del model, processor
        torch.cuda.empty_cache()

    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)

    # Row 0 holds the scores of our single query; order pages best-first.
    # Taking [:top_k] after reversing also returns [] for top_k == 0,
    # whereas [-0:] would have returned every page.
    ranked = scores.argsort(axis=1)[0][::-1]
    results = [images[idx] for idx in ranked[:top_k]]
    print("done")
    return results


def index(files, ds):
    """Convert uploaded PDFs to page images and embed them with ColPali.

    Args:
        files: Uploaded PDFs from the ``gr.File`` component (``None`` when
            nothing was uploaded).
        ds: Previous embedding state. It is intentionally not extended: each
            run rebuilds the index from ``files`` so the returned embeddings
            stay aligned index-for-index with the returned page images.

    Returns:
        ``(status_message, embeddings, images)`` as produced by ``index_gpu``.

    Raises:
        gr.Error: If no files were uploaded.
    """
    if not files:
        raise gr.Error("Please upload at least one PDF before indexing.")
    print("Converting files")
    images = convert_files(files)
    print(f"Files converted with {len(images)} images.")
    # Start from an empty list: extending the old state would keep stale
    # embeddings from earlier runs that no longer match ``images``.
    return index_gpu(images, [])
    


def convert_files(files, max_pages=150):
    """Render every page of the given PDFs to images with ``pdf2image``.

    Args:
        files: Iterable of PDF paths accepted by ``convert_from_path``;
            ``None`` is treated as no files.
        max_pages: Exclusive upper bound on the total number of pages.

    Returns:
        List of page images, in file order and then page order.

    Raises:
        gr.Error: As soon as the running page count reaches ``max_pages``.
    """
    images = []
    for f in files or []:
        images.extend(convert_from_path(f, thread_count=4))
        # Check after each file so an oversized upload fails early instead of
        # rendering every remaining PDF first.
        if len(images) >= max_pages:
            raise gr.Error(f"The number of images in the dataset should be less than {max_pages}.")
    return images


@spaces.GPU
def index_gpu(images, ds):
    """Embed page images with ColPali and append the embeddings to ``ds``.

    Loads the ``vidore/colpali-v1.2`` adapter on top of
    ``vidore/colpaligemma-3b-pt-448-base``, runs the images through it in
    batches of 4, and releases the GPU memory afterwards.

    Args:
        images: Page images to embed (fed to ``process_images``).
        ds: List that receives one CPU embedding tensor per image, in input
            order. It is mutated in place and also returned.

    Returns:
        ``(status_message, ds, images)``.
    """
    # Load colpali model
    model_name = "vidore/colpali-v1.2"
    token = os.environ.get("HF_TOKEN")
    model = ColPali.from_pretrained(
        "vidore/colpaligemma-3b-pt-448-base",
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        token=token,
    )
    model.load_adapter(model_name)
    model = model.eval()
    processor = AutoProcessor.from_pretrained(model_name, token=token)

    # Compare torch.device objects: a plain str never equals a torch.device.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if model.device != device:
        model.to(device)

    # Batch the images; process_images turns each batch into model inputs.
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda batch: process_images(processor, batch),
    )

    try:
        with torch.no_grad():
            for batch_doc in tqdm(dataloader):
                batch_doc = {key: tensor.to(device) for key, tensor in batch_doc.items()}
                embeddings_doc = model(**batch_doc)
                # One tensor per page, moved to CPU so it can live in session state.
                ds.extend(torch.unbind(embeddings_doc.to("cpu")))
    finally:
        # Free GPU memory even if embedding fails part-way.
        del model, processor
        torch.cuda.empty_cache()
    print("done")
    return f"Uploaded and converted {len(images)} pages", ds, images


def get_example():
    """Build the example rows for ``gr.Examples``: ``[[pdf_path], query]`` each."""
    sample_pdf = "RAPPORT_DEVELOPPEMENT_DURABLE_2019.pdf"
    sample_queries = (
        "Quels sont les 4 axes majeurs des achats?",
        "Quelles sont les actions entreprise en Afrique du Sud?",
        "fais moi un tableau markdown de la répartition homme femme",
    )
    # A fresh inner list per row, matching the nested-literal layout.
    return [[[sample_pdf], question] for question in sample_queries]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# ColPali + Qwen2VL 2B: Efficient Document Retrieval with Vision Language Models 📚")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            # Gradio documents file extensions with a leading dot (".pdf").
            file = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload PDFs")

            message = gr.Textbox("Files not yet uploaded", label="Status")
            convert_button = gr.Button("🔄 Index documents")
            # Per-session state: page embeddings and the page images they index.
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])

        with gr.Column(scale=3):
            gr.Markdown("## 2️⃣ Search")
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results", value=1)
            search_button = gr.Button("🔍 Search", variant="primary")

    with gr.Row():
        gr.Examples(
            examples=get_example(),
            inputs=[file, query],
        )

    output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)

    # Wire the actions: index -> search -> answer from the retrieved pages.
    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
    search_button.click(search, inputs=[query, embeds, imgs, k], outputs=[output_gallery])

    answer_button = gr.Button("Answer", variant="primary")
    output = gr.Markdown(label="Output")
    answer_button.click(model_inference, inputs=[output_gallery, query], outputs=output)

if __name__ == "__main__":
    # Bound the request queue; debug=True keeps the process attached and logs errors.
    demo.queue(max_size=10).launch(debug=True)