import functools
import io
import itertools
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, YolosForObjectDetection


# Bounding-box edge colours used by plot_results: RGB triples with components
# in [0, 1]. Detections are assigned these colours in order, wrapping around.
COLORS = [
    list(rgb)
    for rgb in (
        (0.000, 0.447, 0.741),
        (0.850, 0.325, 0.098),
        (0.929, 0.694, 0.125),
        (0.494, 0.184, 0.556),
        (0.466, 0.674, 0.188),
        (0.301, 0.745, 0.933),
    )
]


def process_class_list(classes_string: str):
    """Parse a comma-separated list of class names typed by the user.

    Each name has surrounding whitespace stripped, and empty entries (from
    ``"car,,dog"``, a trailing comma, or whitespace-only input) are dropped.
    An effectively blank field therefore yields ``[]``, which the plotting
    code treats as "show every class".

    Args:
        classes_string: raw textbox contents; ``None`` is treated like ``""``.

    Returns:
        The non-empty class names, in input order.
    """
    if not classes_string:
        return []
    # Keep only entries that still have content after stripping. Otherwise a
    # stray '' would make the filter non-empty and hide every detection.
    return [name.strip() for name in classes_string.split(",") if name.strip()]


_CHECKPOINT = "hustvl/yolos-small-dwr"


@functools.lru_cache(maxsize=None)  # keyed by checkpoint name; only _CHECKPOINT is used
def _load_detector(checkpoint):
    """Load and memoize the feature extractor and YOLOS model for *checkpoint*.

    ``from_pretrained`` is expensive, so it runs once per checkpoint instead of
    once per request.
    """
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = YolosForObjectDetection.from_pretrained(checkpoint)
    model.eval()
    return feature_extractor, model


def model_inference(img, prob_threshold, classes_to_show):
    """Run YOLOS detection on *img* and return an annotated image.

    Args:
        img: image array as delivered by the Gradio input (converted with
            ``Image.fromarray``), or ``None`` when nothing was uploaded.
        prob_threshold: detections whose best class probability is not
            strictly greater than this value are discarded.
        classes_to_show: comma-separated class names to keep; empty keeps all.

    Returns:
        The annotated PIL image produced by plot_results, or ``None`` if
        *img* is ``None``.
    """
    if img is None:
        # E.g. the form was submitted without an image; nothing to detect.
        return None

    feature_extractor, model = _load_detector(_CHECKPOINT)
    pil_img = Image.fromarray(img)
    pixel_values = feature_extractor(pil_img, return_tensors="pt").pixel_values

    with torch.no_grad():
        outputs = model(pixel_values)

    # Class probabilities for the first (only) image in the batch. The last
    # logit column (the detector's "no object" class) is dropped.
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > prob_threshold

    # PIL reports size as (width, height); post_process wants (height, width).
    target_sizes = torch.tensor(pil_img.size[::-1]).unsqueeze(0)
    bboxes_scaled = feature_extractor.post_process(outputs, target_sizes)[0]["boxes"]

    classes_list = process_class_list(classes_to_show)
    return plot_results(pil_img, probas[keep], bboxes_scaled[keep], model, classes_list)


def plot_results(pil_img, prob, boxes, model, classes_list):
    """Draw detections over *pil_img* and return the rendering as a PIL image.

    Args:
        pil_img: background image the detections refer to.
        prob: 2-D tensor of class probabilities, one row per detection; each
            row's arg-max is mapped to a name via ``model.config.id2label``.
        boxes: tensor of ``(xmin, ymin, xmax, ymax)`` boxes aligned row-for-row
            with *prob*. These are presumably pixel coordinates of *pil_img*,
            as produced by model_inference.
        model: object whose ``config.id2label`` maps class indices to names.
        classes_list: class names to keep, matched case-insensitively. An
            empty list, or one containing only blank names, keeps every
            detection.

    Returns:
        The rendered figure, converted to an image by ``fig2img``.
    """
    # Normalise the filter once into a set: O(1) lookups, case-insensitive,
    # and blank entries ignored (a blank would otherwise hide every box).
    wanted = {name.strip().casefold() for name in classes_list if name and name.strip()}

    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(pil_img)
        # Cycle through the palette so any number of boxes gets a colour.
        for p, (xmin, ymin, xmax, ymax), color in zip(prob, boxes.tolist(), itertools.cycle(COLORS)):
            cl = p.argmax()
            object_class = model.config.id2label[cl.item()]
            if wanted and object_class.casefold() not in wanted:
                continue
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
            text = f"{object_class}: {p[cl].item():0.2f}"
            ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
        ax.axis("off")
        return fig2img(fig)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed.
        # Without this, each request would leak a 16x10 figure.
        plt.close(fig)


def fig2img(fig):
    """Rasterise a Matplotlib figure into an in-memory buffer and open it with PIL.

    The figure is written with ``savefig``'s default settings. The returned
    image reads from that buffer, which it keeps referenced.
    """
    stream = io.BytesIO()
    fig.savefig(stream)
    stream.seek(0)  # rewind so PIL reads from the start of the written data
    return Image.open(stream)


description = """Upload an image and get the detected classes"""
title = """Object Detection"""

# NOTE: example images (e.g. from an ``examples/`` directory) can be shown under
# the form by passing ``examples=[["path/to/image.jpg"], ...]`` to gr.Interface.

# type="numpy" is explicit because model_inference converts the upload with
# Image.fromarray, which needs an array rather than a PIL image or file path.
image_in = gr.components.Image(label="Upload an image", type="numpy")
image_out = gr.components.Image()
classes_to_show = gr.components.Textbox(placeholder="e.g. car, dog", label="Classes to filter (leave empty to detect all classes)")
prob_threshold_slider = gr.components.Slider(minimum=0, maximum=1.0, step=0.01, value=0.7, label="Probability Threshold")
# Order must match model_inference's parameters: (img, prob_threshold, classes_to_show).
inputs = [image_in, prob_threshold_slider, classes_to_show]

# Bound to a module-level name so the app object can be reused
# (e.g. found by Gradio's reload mode) rather than discarded.
demo = gr.Interface(
    fn=model_inference,
    inputs=inputs,
    outputs=image_out,
    title=title,
    description=description,
)
demo.launch()