from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
import gradio as gr
import os

from .common_gui import get_folder_path, scriptdir, list_dirs
from .custom_logging import setup_logging

# Set up logging
# Module-level logger used by the captioning functions in this file.
log = setup_logging()


def load_model():
    """
    Load the BLIP2 (OPT-2.7b) processor and model and move the model to a device.

    The model weights are requested in half precision (torch.float16) regardless
    of which device is selected.

    Returns:
    - A tuple ``(processor, model, device)`` where ``device`` is ``"cuda"`` when
      CUDA is available and ``"cpu"`` otherwise.
    """
    checkpoint = "Salesforce/blip2-opt-2.7b"

    # Prefer the GPU; fall back to the CPU when CUDA is unavailable.
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"

    blip2_processor = Blip2Processor.from_pretrained(checkpoint)
    blip2_model = Blip2ForConditionalGeneration.from_pretrained(
        checkpoint, torch_dtype=torch.float16
    )
    # NOTE(review): float16 weights are also used on the CPU path, which is slow
    # and may hit unsupported half-precision ops on older torch builds — confirm.
    blip2_model.to(target_device)

    return blip2_processor, blip2_model, target_device


def get_images_in_directory(directory_path):
    """
    Returns a sorted list of image file paths found in the provided directory path.

    Only regular files located directly inside ``directory_path`` are considered
    (sub-directories are neither descended into nor returned, even if their name
    ends in an image extension). Extensions are matched case-insensitively
    against .jpg, .jpeg, .png, .bmp and .gif.

    Parameters:
    - directory_path: A string representing the path to the directory to search for images.

    Returns:
    - A list of strings, ordered by file name, where each string is the full path
      to an image file found in the specified directory.

    Raises:
    - OSError (e.g. FileNotFoundError, NotADirectoryError) if the directory cannot be listed.
    """
    # Common image file extensions to look for (compared in lowercase).
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}

    # sorted() gives a deterministic processing order; os.listdir order is arbitrary.
    candidates = (
        os.path.join(directory_path, file_name)
        for file_name in sorted(os.listdir(directory_path))
        if os.path.splitext(file_name)[1].lower() in image_extensions
    )

    # Drop directories (or other non-files) that merely have an image-like name,
    # since Image.open would fail on them later.
    return [path for path in candidates if os.path.isfile(path)]


def generate_caption(
    file_list,
    processor,
    model,
    device,
    caption_file_ext=".txt",
    num_beams=5,
    repetition_penalty=1.5,
    length_penalty=1.2,
    max_new_tokens=40,
    min_new_tokens=20,
    do_sample=True,
    temperature=1.0,
    top_p=0.0,
):
    """
    Fetches and processes each image in file_list, generates a caption for it, and writes the caption to a text file next to the image.

    The decoding strategy is selected by ``top_p``: when ``top_p == 0.0`` beam
    search is used (``num_beams``, ``repetition_penalty``, ``length_penalty``);
    otherwise generation uses ``do_sample``, ``top_p`` and ``temperature``.
    ``max_new_tokens`` and ``min_new_tokens`` apply to both strategies.
    Existing caption files are overwritten.

    Parameters:
    - file_list: A list of file paths pointing to the images to be captioned.
    - processor: The preprocessor for the BLIP2 model.
    - model: The BLIP2 model to be used for generating captions.
    - device: The device on which the computation is performed.
    - caption_file_ext: Suffix that replaces each image's extension to form the caption file path. Default: ".txt".
    - num_beams: Number of beams for beam search (beam search only). Default: 5.
    - repetition_penalty: Penalty for repeating tokens (beam search only). Default: 1.5.
    - length_penalty: Penalty for sentence length (beam search only). Default: 1.2.
    - max_new_tokens: Maximum number of new tokens to generate. Default: 40.
    - min_new_tokens: Minimum number of new tokens to generate. Default: 20.
    - do_sample: Whether to sample (only used when top_p != 0.0). Default: True.
    - temperature: Sampling temperature (only used when top_p != 0.0). Default: 1.0.
    - top_p: Nucleus-sampling probability mass; 0.0 selects beam search instead. Default: 0.0.
    """
    # Cast pixel inputs to the model's own weight dtype instead of a hard-coded
    # float16, so inputs always match however load_model() loaded the weights.
    input_dtype = model.dtype

    for file_path in file_list:
        # The context manager releases the file handle deterministically; Pillow
        # may otherwise keep it open (e.g. for multi-frame images such as GIFs).
        with Image.open(file_path) as image:
            inputs = processor(images=image, return_tensors="pt").to(
                device, input_dtype
            )

        if top_p == 0.0:
            generated_ids = model.generate(
                **inputs,
                num_beams=num_beams,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
            )
        else:
            generated_ids = model.generate(
                **inputs,
                do_sample=do_sample,
                top_p=top_p,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
                temperature=temperature,
            )

        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()

        # Construct the output file path by replacing the original file extension with the specified extension
        output_file_path = os.path.splitext(file_path)[0] + caption_file_ext

        # Write the generated text to the output file (overwrites any existing caption)
        with open(output_file_path, "w", encoding="utf-8") as output_file:
            output_file.write(generated_text)

        # Log the image file path with a message about the fact that the caption was generated
        log.info(f"{file_path} caption was generated")


def caption_images_beam_search(
    directory_path,
    num_beams,
    repetition_penalty,
    length_penalty,
    min_new_tokens,
    max_new_tokens,
    caption_file_ext,
):
    """
    Captions all images in the specified directory using BLIP2 with beam-search decoding.

    Captions are written next to each image by generate_caption; nothing is returned.
    Numeric arguments are normalized (int/float) because they arrive from Gradio widgets.

    Parameters:
    - directory_path: A string representing the path to the directory containing the images to be captioned.
    - num_beams: Number of beams for beam search (cast to int).
    - repetition_penalty: Penalty for repeating tokens (cast to float).
    - length_penalty: Penalty for sequence length (cast to float).
    - min_new_tokens: Minimum number of new tokens to generate (cast to int).
    - max_new_tokens: Maximum number of new tokens to generate (cast to int).
    - caption_file_ext: Suffix used for the caption files, e.g. ".txt".
    """
    log.info("BLIP2 captionning beam...")

    # Gradio may deliver None or "" when no folder is selected; os.path.isdir(None)
    # can raise TypeError rather than returning False.
    if not directory_path or not os.path.isdir(directory_path):
        log.error(f"Directory {directory_path} does not exist.")
        return

    image_files = get_images_in_directory(directory_path)
    if not image_files:
        log.info(f"No images found in {directory_path}, nothing to caption.")
        return

    # Load the (expensive) BLIP2 checkpoint only once there is something to caption.
    processor, model, device = load_model()
    generate_caption(
        file_list=image_files,
        processor=processor,
        model=model,
        device=device,
        num_beams=int(num_beams),
        repetition_penalty=float(repetition_penalty),
        length_penalty=float(length_penalty),
        min_new_tokens=int(min_new_tokens),
        max_new_tokens=int(max_new_tokens),
        caption_file_ext=caption_file_ext,
    )


def caption_images_nucleus(
    directory_path,
    do_sample,
    temperature,
    top_p,
    min_new_tokens,
    max_new_tokens,
    caption_file_ext,
):
    """
    Captions all images in the specified directory using BLIP2 with nucleus (top-p) sampling.

    Captions are written next to each image by generate_caption; nothing is returned.
    Arguments are normalized (bool/int/float) because they arrive from Gradio widgets.
    Note that generate_caption falls back to beam search when top_p == 0.0.

    Parameters:
    - directory_path: A string representing the path to the directory containing the images to be captioned.
    - do_sample: Whether to sample during generation (cast to bool).
    - temperature: Sampling temperature (cast to float).
    - top_p: Nucleus-sampling probability mass (cast to float).
    - min_new_tokens: Minimum number of new tokens to generate (cast to int).
    - max_new_tokens: Maximum number of new tokens to generate (cast to int).
    - caption_file_ext: Suffix used for the caption files, e.g. ".txt".
    """
    log.info("BLIP2 captionning nucleus...")

    # Gradio may deliver None or "" when no folder is selected; os.path.isdir(None)
    # can raise TypeError rather than returning False.
    if not directory_path or not os.path.isdir(directory_path):
        log.error(f"Directory {directory_path} does not exist.")
        return

    image_files = get_images_in_directory(directory_path)
    if not image_files:
        log.info(f"No images found in {directory_path}, nothing to caption.")
        return

    # Load the (expensive) BLIP2 checkpoint only once there is something to caption.
    processor, model, device = load_model()
    generate_caption(
        file_list=image_files,
        processor=processor,
        model=model,
        device=device,
        do_sample=bool(do_sample),
        temperature=float(temperature),
        top_p=float(top_p),
        min_new_tokens=int(min_new_tokens),
        max_new_tokens=int(max_new_tokens),
        caption_file_ext=caption_file_ext,
    )


def gradio_blip2_caption_gui_tab(headless=False, directory_path=None):
    """
    Build the "BLIP2 Captioning" tab of the Gradio UI.

    The tab holds a shared folder selector, min/max new-token inputs and a
    caption-extension textbox, plus two sub-tabs whose buttons start captioning:
    "Beam search" (calls caption_images_beam_search) and "Nucleus sampling"
    (calls caption_images_nucleus).

    Parameters:
    - headless: When True, the "📂" folder-picker button is created hidden.
    - directory_path: Directory whose listing seeds the folder dropdown; when
      None, os.path.join(scriptdir, "data") is used.
    """
    from .common_gui import create_refresh_button

    # Fall back to <scriptdir>/data when no directory is given.
    directory_path = (
        directory_path
        if directory_path is not None
        else os.path.join(scriptdir, "data")
    )
    # Last directory that was listed; the refresh button re-lists this one.
    current_train_dir = directory_path

    def list_train_dirs(path):
        # Remember `path` as the current directory and return list_dirs(path)
        # materialized as a list (presumably sub-directory paths — see common_gui).
        nonlocal current_train_dir
        current_train_dir = path
        return list(list_dirs(path))

    with gr.Tab("BLIP2 Captioning"):
        gr.Markdown(
            "This utility uses BLIP2 to caption files for each image in a folder."
        )

        # Folder selection row: editable dropdown, refresh button, folder picker.
        with gr.Group(), gr.Row():
            directory_path_dir = gr.Dropdown(
                label="Image folder to caption (containing the images to caption)",
                choices=[""] + list_train_dirs(directory_path),
                value="",
                interactive=True,
                allow_custom_value=True,
            )
            # NOTE(review): assumes create_refresh_button refreshes the dropdown's
            # choices from the dict returned by the third argument — confirm in common_gui.
            create_refresh_button(
                directory_path_dir,
                lambda: None,
                lambda: {"choices": list_train_dirs(current_train_dir)},
                "open_folder_small",
            )
            # Folder picker button; its result (from get_folder_path) is written
            # into the dropdown. Hidden in headless mode.
            button_directory_path_dir_input = gr.Button(
                "📂",
                elem_id="open_folder_small",
                elem_classes=["tool"],
                visible=(not headless),
            )
            button_directory_path_dir_input.click(
                get_folder_path,
                outputs=directory_path_dir,
                show_progress=False,
            )
        # Generation-length and output-extension settings shared by both sub-tabs.
        with gr.Group(), gr.Row():
            min_new_tokens = gr.Number(
                value=20,
                label="Min new tokens",
                interactive=True,
                step=1,
                minimum=5,
                maximum=300,
            )
            max_new_tokens = gr.Number(
                value=40,
                label="Max new tokens",
                interactive=True,
                step=1,
                minimum=5,
                maximum=300,
            )
            caption_file_ext = gr.Textbox(
                label="Caption file extension",
                placeholder="Extension for caption file (e.g., .caption, .txt)",
                value=".txt",
                interactive=True,
            )

        with gr.Row():
            with gr.Tab("Beam search"):
                with gr.Row():
                    num_beams = gr.Slider(
                        minimum=1,
                        maximum=16,
                        value=16,
                        step=1,
                        interactive=True,
                        label="Number of beams",
                    )

                    len_penalty = gr.Slider(
                        minimum=-1.0,
                        maximum=2.0,
                        value=1.0,
                        step=0.2,
                        interactive=True,
                        label="Length Penalty",
                        info="increase for longer sequence",
                    )

                    rep_penalty = gr.Slider(
                        minimum=1.0,
                        maximum=5.0,
                        value=1.5,
                        step=0.5,
                        interactive=True,
                        label="Repeat Penalty",
                        info="larger value prevents repetition",
                    )

                caption_button_beam = gr.Button(
                    value="Caption images", interactive=True, variant="primary"
                )
                # Inputs are passed positionally: this order must match the
                # parameter order of caption_images_beam_search.
                caption_button_beam.click(
                    caption_images_beam_search,
                    inputs=[
                        directory_path_dir,
                        num_beams,
                        rep_penalty,
                        len_penalty,
                        min_new_tokens,
                        max_new_tokens,
                        caption_file_ext,
                    ],
                )
            with gr.Tab("Nucleus sampling"):
                with gr.Row():
                    do_sample = gr.Checkbox(label="Sample", value=True)

                    temperature = gr.Slider(
                        minimum=0.5,
                        maximum=1.0,
                        value=1.0,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                        info="used with nucleus sampling",
                    )

                    # NOTE: top_p == 0 makes generate_caption switch to beam
                    # search (its `top_p == 0.0` check) instead of sampling.
                    top_p = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.9,
                        step=0.1,
                        interactive=True,
                        label="Top_p",
                    )

                caption_button_nucleus = gr.Button(
                    value="Caption images", interactive=True, variant="primary"
                )
                # Inputs are passed positionally: this order must match the
                # parameter order of caption_images_nucleus.
                caption_button_nucleus.click(
                    caption_images_nucleus,
                    inputs=[
                        directory_path_dir,
                        do_sample,
                        temperature,
                        top_p,
                        min_new_tokens,
                        max_new_tokens,
                        caption_file_ext,
                    ],
                )