import gradio as gr
import torch
import clip

# Run CLIP on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
# clip.load returns the model plus the image transform that predict() applies
# to each input image before inference.
model, preprocess = clip.load("ViT-B/32", device=device)


def predict(image, labels):
    labels = labels.split(',')
    image = preprocess(image).unsqueeze(0).to(device)
    text = clip.tokenize([f"a photo of a {c}" for c in labels]).to(device)

    with torch.inference_mode():
        logits_per_image, logits_per_text = model(image, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    return {k: float(v) for k, v in zip(labels, probs[0])}

# Manual smoke test for predict() (also needs `from PIL import Image`):
# probs = predict(Image.open("../CLIP/CLIP.png"), "cat, dog, ball")
# print(probs)


# Build the web UI: an image upload plus a free-text box of comma-separated
# class names, fed to predict(); the resulting dict is shown as a label widget.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.inputs.Image(label="Image to classify.", type="pil"),
        gr.inputs.Textbox(
            lines=1,
            label="Comma separated classes",
            placeholder="Enter your classes separated by ','",
        ),
    ],
    theme="grass",
    outputs="label",
    description="Zero Shot Image classification..",
)
demo.launch()