# Hugging Face Spaces status: Runtime error
import gradio as gr
import torch
from datasets import load_dataset
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
# Hub checkpoint of the fine-tuned CLIP model. The processor (tokenizer + image
# preprocessing) and the model weights are both loaded from this one repo.
_CHECKPOINT = "DGurgurov/clip-vit-base-patch32-oxford-pets"

processor = AutoProcessor.from_pretrained(_CHECKPOINT)
model = AutoModelForZeroShotImageClassification.from_pretrained(_CHECKPOINT)
# Load the dataset only to recover the class names; each name is later fed to
# the CLIP processor as a text prompt.
dataset = load_dataset("pcuenq/oxford-pets")  # Adjust dataset loading as per your setup

# Sort the distinct labels so that their order, and therefore the index each
# label is mapped to, is the same on every run. Iterating a ``set`` of strings
# gives an order that differs between processes because of hash randomization.
labels = sorted(dataset["train"].unique("label"))
label2id = {label: i for i, label in enumerate(labels)}
id2label = dict(enumerate(labels))
def classify_image(image):
    """Classify a pet photo with the fine-tuned CLIP model.

    Each entry of the module-level ``labels`` list is scored as a text prompt
    against the image, and the highest-scoring label is returned.

    Args:
        image: The uploaded picture. ``gr.Image`` delivers a NumPy array by
            default (its ``type="numpy"``); a ``PIL.Image.Image`` is also
            accepted and used as-is.

    Returns:
        The predicted label, one element of ``labels``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Gradio passes None when Submit is pressed without an upload. Show a
        # readable message in the UI instead of a PIL AttributeError.
        raise gr.Error("Please upload an image first.")
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Tokenize every label and preprocess the image in a single call.
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)

    # Inference only: disable autograd so no gradient graph is built and kept
    # alive per request.
    with torch.inference_mode():
        outputs = model(**inputs)

    # logits_per_image has shape [num_images, num_labels]. Exactly one image is
    # passed here, so take row 0. Softmax is monotonic, so the argmax over raw
    # logits selects the same label as the argmax over probabilities.
    predicted_label_id = outputs.logits_per_image[0].argmax().item()
    return id2label[predicted_label_id]
# Web UI: a single image upload goes in, and the predicted label comes out as text.
image_input = gr.Image(label="Upload a picture of an animal")
label_output = gr.Textbox(label="Predicted Animal")

iface = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=label_output,
    title="Animal Classifier",
    description="CLIP-based model fine-tuned on Oxford Pets dataset to classify animals.",
)
# Start the web server only when this file runs as a script (e.g. `python app.py`),
# so importing the module (tests, notebooks, tooling) does not launch a server.
if __name__ == "__main__":
    iface.launch()