import gradio as gr
from transformers import AutoImageProcessor, AutoModel
import torch
from PIL import Image
import json
import numpy as np
import faiss


# --- One-time initialisation (runs at import): model, traced graph, index, id map ---

# All inference runs on the CPU.
device = torch.device("cpu")
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
model = AutoModel.from_pretrained("facebook/dinov2-large")
# torch.jit.trace needs plain tuple outputs instead of a ModelOutput object.
model.config.return_dict = False
model.to(device)
# Be explicit about inference mode so train-only behaviour is not baked into the trace.
model.eval()

# Trace once with a dummy (1, 3, 224, 224) batch; the traced module is reused per request.
# NOTE(review): a trace is recorded for this input shape — confirm the processor
# always produces 224x224 pixel_values.
example_input = torch.rand(1, 3, 224, 224).to(device)
with torch.no_grad():
    # No gradients are needed for tracing an inference-only model.
    traced_model = torch.jit.trace(model, example_input)
traced_model = traced_model.to(device)

# faiss index that process_image queries with normalised embeddings.
index = faiss.read_index("xbgp-faiss.index")

# Map from faiss result position to image id. UTF-8 is explicit so loading does
# not depend on the platform's default locale encoding.
with open("xbgp-faiss-map.json", "r", encoding="utf-8") as f:
    images = json.load(f)


def process_image(image):
    """
    Find the indexed images most similar to an uploaded image.

    The image is converted to RGB and resized so its shorter side is 224 px
    (aspect ratio preserved). It is then embedded with the traced DINOv2
    model; token features are mean-pooled and L2-normalised. The result is
    used to query the faiss index for up to 50 nearest neighbours.

    Args:
        image: The uploaded ``PIL.Image.Image``; Gradio passes ``None`` when
            the form is submitted without an upload.

    Returns:
        A list of ``{"id": ..., "score": "NN.NN%"}`` dicts in the order faiss
        returns them. ``id`` is looked up in the JSON map loaded at start-up.
        ``score`` is ``1 / (distance + 1)`` expressed as a percentage. Fewer
        than 50 entries are returned if the index holds fewer vectors.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image.")

    if image.mode != "RGB":
        image = image.convert("RGB")

    # Scale so the shorter side becomes 224 px, keeping the aspect ratio.
    # NOTE(review): the DINOv2 image processor may apply its own resize/crop
    # as well — confirm whether this pre-resize is still required.
    width, height = image.size
    if width < height:
        new_width = 224
        new_height = int(height * (224 / width))
    else:
        new_height = 224
        new_width = int(width * (224 / height))
    image = image.resize((new_width, new_height), Image.LANCZOS)

    with torch.no_grad():
        pixel_values = processor(images=image, return_tensors="pt")["pixel_values"].to(device)
        outputs = traced_model(pixel_values)

    # outputs[0] is the model's first output (presumably the last hidden state,
    # (batch, tokens, dim)); averaging over dim=1 yields one vector per image.
    embeddings = outputs[0].mean(dim=1)
    # faiss requires a C-contiguous float32 matrix. normalize_L2 rescales each
    # row in place to unit length, so distances compare direction, not magnitude.
    query = np.ascontiguousarray(embeddings.detach().cpu().numpy(), dtype=np.float32)
    faiss.normalize_L2(query)

    distances, indices = index.search(query, 50)

    matches = []
    for distance, match_id in zip(distances[0], indices[0]):
        # faiss pads missing results with -1; images[-1] would otherwise
        # silently return the last entry as a bogus match.
        if match_id < 0:
            continue
        matches.append(
            {
                "id": images[match_id],
                "score": str(round((1 / (distance + 1) * 100), 2)) + "%",
            }
        )

    return matches


# Web UI: a single PIL image input, with the list of matches rendered as JSON.
# Requests go through Gradio's queue so the slow model call does not block others.
iface = gr.Interface(
    process_image,
    gr.Image(type="pil"),
    "json",
)
iface = iface.queue()

# Start serving the app.
iface.launch()