Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from transformers import AutoImageProcessor, AutoModelForImageClassification | |
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomHorizontalFlip, RandomRotation | |
from PIL import Image | |
import traceback | |
# Human-readable class names, indexed by class id, taken from the CUB-200-2011
# dataset's "label" feature. Assumes the model's output indices follow this
# dataset's ClassLabel order — NOTE(review): confirm against the model config.
# load_dataset was previously used without ever being imported, which made the
# script crash with NameError at startup; import it here so this block works.
from datasets import load_dataset

dataset = load_dataset("bentrevett/caltech-ucsd-birds-200-2011")
labels = dataset["train"].features["label"].names
# The fine-tuned bird classifier and its paired image processor are both
# pulled from the same Hub repository, so their settings stay in sync.
model_name = "riyadifirman/klasifikasiburung"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)
# Inference-time preprocessing applied to every uploaded image.
# The 224x224 size is assumed to match the checkpoint's training resolution
# (NOTE(review): confirm against processor.size). Normalization uses the
# mean/std reported by the checkpoint's own image processor.
# Random augmentations (horizontal flip, rotation) are deliberately omitted:
# they are training-only tricks and, at inference, made the prediction for the
# same upload change from one call to the next.
normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
transform = Compose([
    Resize((224, 224)),
    ToTensor(),  # PIL image -> channels-first float tensor scaled to [0, 1]
    normalize,
])
def predict(image):
    """Classify a bird photo and return the predicted species name.

    Args:
        image: Image array as delivered by the ``gr.Image(type="numpy")``
            input, or ``None`` when the user submits without uploading.

    Returns:
        The label name of the highest-scoring class, or a human-readable
        message if no image was given or inference failed.
    """
    if image is None:
        # Gradio passes None when Submit is clicked with no image uploaded.
        return "Please upload an image of a bird."
    try:
        # Force 3-channel RGB so the tensor matches the 3-element mean/std in
        # Normalize; an RGBA or grayscale image would otherwise fail there.
        pil_image = Image.fromarray(image).convert("RGB")
        inputs = transform(pil_image).unsqueeze(0)  # add a batch dimension
        # Inference only: disable autograd tracking to save memory and time.
        with torch.no_grad():
            logits = model(inputs).logits
        predicted_class_idx = logits.argmax(-1).item()
        return labels[predicted_class_idx]
    except Exception as e:
        # Top-level UI boundary: print the error and its full traceback on the
        # server, and show the user a generic message instead of crashing.
        print("An error occurred:", e)
        print(traceback.format_exc())
        return "An error occurred while processing your request."
# Web UI wiring: a single uploaded image goes in (handed to predict as a numpy
# array) and the predicted species name comes back as plain text.
interface = gr.Interface(
    predict,
    gr.Image(type="numpy"),
    "text",
    description="Upload an image of a bird to classify it.",
    title="Bird Classification",
)
def _main():
    """Start the Gradio server for the bird classifier UI."""
    # share=True asks Gradio to also create a temporary public link.
    interface.launch(share=True)


if __name__ == "__main__":
    _main()