import spaces
import gradio as gr

import torch
import matplotlib.pyplot as plt 
from PIL import Image, ImageDraw, ImageFont 
import requests 
from io import BytesIO 
import numpy as np 

# load a simple face detector 
from retinaface import RetinaFace 

device = "cuda" if torch.cuda.is_available() else "cpu"

# load Gaze-LLE model (DINOv2 ViT-L/14 backbone, variant with gaze in/out-of-frame head)
model, transform = torch.hub.load("fkryan/gazelle", "gazelle_dinov2_vitl14_inout")
# Inference only: switch to eval mode and place the weights on the same device
# that main() moves the input tensor to; otherwise a CUDA run would mix a CPU
# model with a CUDA input.
model.eval()
model.to(device)

def visualize_heatmap(pil_image, heatmap, bbox=None, inout_score=None, heatmap_alpha=90):
    """Overlay one person's gaze heatmap on an image.

    Args:
        pil_image: Source PIL image; it is not modified.
        heatmap: 2-D gaze heatmap as a torch tensor or numpy array. Values are
            multiplied by 255 for display, so they are assumed to lie in
            [0, 1] (TODO confirm against the model output); they are clipped
            to that range to avoid uint8 wrap-around.
        bbox: Optional head box ``(xmin, ymin, xmax, ymax)`` normalised to
            [0, 1] by image width/height; drawn as a lime rectangle.
        inout_score: Optional gaze-in-frame score, printed below the box
            (only drawn when ``bbox`` is given).
        heatmap_alpha: Opacity (0-255) of the colour-mapped heatmap layer.

    Returns:
        A new RGBA PIL image with the heatmap (and optional box/score) drawn on it.
    """
    if isinstance(heatmap, torch.Tensor):
        heatmap = heatmap.detach().cpu().numpy()
    heatmap = np.clip(heatmap, 0.0, 1.0)
    # Upsample the low-resolution heatmap to the image size, then colour-map it.
    heatmap = Image.fromarray((heatmap * 255).astype(np.uint8)).resize(pil_image.size, Image.Resampling.BILINEAR)
    heatmap = plt.cm.jet(np.array(heatmap) / 255.)  # float RGBA in [0, 1]
    heatmap = (heatmap[:, :, :3] * 255).astype(np.uint8)
    heatmap = Image.fromarray(heatmap).convert("RGBA")
    # Without a translucent alpha channel the colour map would completely hide the photo.
    heatmap.putalpha(heatmap_alpha)
    overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), heatmap)

    if bbox is not None:
        width, height = pil_image.size
        xmin, ymin, xmax, ymax = bbox
        draw = ImageDraw.Draw(overlay_image)
        # max(1, ...) keeps strokes/text visible on tiny images where the scaled size rounds to 0.
        line_width = max(1, int(min(width, height) * 0.01))
        draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline="lime", width=line_width)

        if inout_score is not None:
            text = f"in-frame: {float(inout_score):.2f}"
            text_y = ymax * height + int(height * 0.01)  # small gap below the box
            font = ImageFont.load_default(size=max(1, int(min(width, height) * 0.05)))
            draw.text((xmin * width, text_y), text, fill="lime", font=font)
    return overlay_image

def visualize_all(pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
    """Draw every person's head box, in-frame score and gaze arrow on one image.

    Args:
        pil_image: Source PIL image; it is not modified.
        heatmaps: Sequence of 2-D gaze heatmaps (torch tensors or numpy
            arrays), indexed in the same order as ``bboxes``.
        bboxes: Head boxes ``(xmin, ymin, xmax, ymax)`` normalised to [0, 1]
            by image width/height.
        inout_scores: Per-person gaze-in-frame scores, or ``None``. When
            ``None``, only the boxes are drawn (no score text, no arrow).
        inout_thresh: An arrow to the heatmap's arg-max is drawn only for
            people whose in-frame score is greater than this threshold.

    Returns:
        A new RGBA PIL image with the annotations.
    """
    colors = ['lime', 'tomato', 'cyan', 'fuchsia', 'yellow']
    overlay_image = pil_image.convert("RGBA")
    draw = ImageDraw.Draw(overlay_image)
    width, height = pil_image.size
    # Sizes scale with the image; max(1, ...) avoids zero-sized strokes/fonts on tiny images.
    box_width = max(1, int(min(width, height) * 0.01))
    arrow_width = max(1, int(0.005 * min(width, height)))
    font = None
    if inout_scores is not None:
        font = ImageFont.load_default(size=max(1, int(min(width, height) * 0.05)))

    for i, (xmin, ymin, xmax, ymax) in enumerate(bboxes):
        color = colors[i % len(colors)]
        draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline=color, width=box_width)

        if inout_scores is None:
            continue

        inout_score = float(inout_scores[i])
        text_y = ymax * height + int(height * 0.01)  # small gap below the box
        draw.text((xmin * width, text_y), f"in-frame: {inout_score:.2f}", fill=color, font=font)

        # Written as `not >` (rather than `<=`) so a NaN score is skipped, as before.
        if not inout_score > inout_thresh:
            continue

        heatmap_np = heatmaps[i]
        if isinstance(heatmap_np, torch.Tensor):
            heatmap_np = heatmap_np.detach().cpu().numpy()
        # Map the top-left corner of the arg-max heatmap cell back to image pixels.
        row, col = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)
        gaze_target_x = col / heatmap_np.shape[1] * width
        gaze_target_y = row / heatmap_np.shape[0] * height
        bbox_center_x = ((xmin + xmax) / 2) * width
        bbox_center_y = ((ymin + ymax) / 2) * height

        draw.ellipse([(gaze_target_x - 5, gaze_target_y - 5), (gaze_target_x + 5, gaze_target_y + 5)], fill=color, width=arrow_width)
        draw.line([(bbox_center_x, bbox_center_y), (gaze_target_x, gaze_target_y)], fill=color, width=arrow_width)

    return overlay_image

@spaces.GPU() # ZeroGPU ready
def main(image_input, progress=gr.Progress(track_tqdm=True)):
    """Detect faces in an image and predict where each person is looking.

    Args:
        image_input: Path to the uploaded image (the input component is
            created with ``type="filepath"``).
        progress: Gradio progress tracker (``track_tqdm=True``); not used
            directly in this function.

    Returns:
        A tuple ``(result_gazed, heatmap_results)``: one RGBA image with a box,
        in-frame score and gaze arrow per detected face, and a list of
        per-face heatmap overlays for the gallery.

    Raises:
        gr.Error: If no image was supplied or no face was detected.
    """
    if image_input is None:
        raise gr.Error("Please upload an image first.")

    # load image; force 3-channel RGB so RGBA/greyscale/palette files work with
    # both the face detector and the Gaze-LLE transform. `with` closes the file.
    with Image.open(image_input) as src:
        image = src.convert("RGB")
    width, height = image.size

    # detect faces
    # NOTE(review): RetinaFace receives an RGB array here; confirm whether it expects BGR.
    resp = RetinaFace.detect_faces(np.array(image))
    # Guard against an empty result (some retina-face versions return a non-dict when nothing is found).
    if not isinstance(resp, dict) or not resp:
        raise gr.Error("No faces were detected in this image.")
    bboxes = [face["facial_area"] for face in resp.values()]

    # prepare gazelle input: pixel boxes -> boxes normalised to [0, 1]
    img_tensor = transform(image).unsqueeze(0).to(device)
    scale = np.array([width, height, width, height])
    norm_bboxes = [[np.array(bbox) / scale for bbox in bboxes]]

    model_input = {
        "images": img_tensor,  # [num_images, 3, 448, 448]
        "bboxes": norm_bboxes,  # [[img1_bbox1, img1_bbox2...], [img2_bbox1, img2_bbox2]...]
    }

    with torch.no_grad():
        output = model(model_input)

    # Only one image is passed, so take index 0 of the per-image outputs.
    image_heatmaps = output["heatmap"][0]
    inout = output.get("inout")  # None when the model has no in/out-of-frame head
    image_inout = inout[0] if inout is not None else None

    # visualize predicted gaze heatmap for each person and gaze in/out of frame score
    heatmap_results = [
        visualize_heatmap(
            image,
            image_heatmaps[i],
            norm_bboxes[0][i],
            inout_score=image_inout[i] if image_inout is not None else None,
        )
        for i in range(len(bboxes))
    ]

    # combined visualization with maximal gaze points for each person
    result_gazed = visualize_all(image, image_heatmaps, norm_bboxes[0], image_inout, inout_thresh=0.5)

    return result_gazed, heatmap_results

    margin: 0 auto;
    max-width: 982px;

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Gaze-LLE: Gaze Target Estimation via Large-Scale Learned Encoders")
        gr.Markdown("A transformer approach for estimating gaze targets that leverages the power of pretrained visual foundation models. Gaze-LLE provides a streamlined gaze architecture that learns only a lightweight gaze decoder on top of a frozen, pretrained visual encoder (DINOv2). Gaze-LLE learns 1-2 orders of magnitude fewer parameters than prior works and doesn't require any extra input modalities like depth and pose!")
        # Badge row. TODO: the link targets and badge image URLs are empty in
        # this copy of the file — fill them in before deploying.
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href="">
                <img src=''>
            </a>
            <a href="">
                <img src=''>
            </a>
            <a href="">
                <img src="" alt="Duplicate this Space">
            </a>
            <a href="">
                <img src="" alt="Follow me on HF">
            </a>
        </div>
        """)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Image Input", type="filepath")
                submit_button = gr.Button("Submit")
                # Single input component, so a flat list of example values is accepted.
                gr.Examples(
                    examples=["examples/the_office.png", "examples/succession.png"],
                    inputs=[input_image],
                )
            with gr.Column():
                result = gr.Image(label="Result")
                heatmaps = gr.Gallery(label="Heatmap", columns=3)

    # Run inference on click: combined overlay -> `result`, per-person overlays -> gallery.
    submit_button.click(
        fn=main,
        inputs=[input_image],
        outputs=[result, heatmaps],
    )

demo.queue().launch(show_api=False, show_error=True)