# app.py by riyadifirman (commit 20115ca, "Update app.py"; 1.68 kB)
import gradio as gr
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomHorizontalFlip, RandomRotation
from PIL import Image
import traceback
# The dataset is loaded only to recover its ordered list of class names, so
# that a predicted class index can be mapped back to a human-readable label.
# `load_dataset` was previously called without ever being imported, which
# raised a NameError at startup; import it here so this block is self-contained.
# NOTE(review): assumes the model's output indices follow this dataset's
# 'label' ordering (i.e. it was trained on it) — confirm.
from datasets import load_dataset  # noqa: E402

dataset = load_dataset("bentrevett/caltech-ucsd-birds-200-2011")
labels = dataset['train'].features['label'].names
# Both the image processor (normalization statistics) and the classification
# model are pulled from the same Hub repository id; the processor is loaded
# first, then the model.
model_name = "riyadifirman/klasifikasiburung"
processor, model = (
    AutoImageProcessor.from_pretrained(model_name),
    AutoModelForImageClassification.from_pretrained(model_name),
)
# Inference-time preprocessing: resize to 224x224, convert the PIL image to a
# float tensor scaled to [0, 1], then normalize with the processor's mean/std.
# NOTE: the previous pipeline also applied RandomHorizontalFlip() and
# RandomRotation(10). Those are training-time augmentations; at inference they
# made the same uploaded image produce different predictions on each call, so
# they are intentionally omitted here to keep predictions deterministic.
normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    normalize,
])
def predict(image):
    """Classify an uploaded image and return the predicted class name.

    Args:
        image: The uploaded picture as a NumPy array (what
            ``gr.Image(type="numpy")`` delivers), or ``None`` when the user
            submits without providing an image.

    Returns:
        The label string for the highest-scoring class, or a human-readable
        message when no image was given or inference failed.
    """
    # Gradio passes None when the image component is empty; answer clearly
    # instead of letting Image.fromarray(None) raise.
    if image is None:
        return "Please upload an image of a bird."
    try:
        # Defensive: force 3-channel RGB so grayscale/RGBA arrays (e.g. when
        # called outside the Gradio UI) still match the per-channel mean/std
        # used by `normalize` (presumably 3 values — confirm for this processor).
        pil_image = Image.fromarray(image).convert("RGB")
        # Add a leading batch dimension: (C, H, W) -> (1, C, H, W).
        inputs = transform(pil_image).unsqueeze(0)
        # Pure inference: disable autograd so no computation graph is built,
        # saving memory and time on every request.
        with torch.no_grad():
            logits = model(inputs).logits
        predicted_class_idx = logits.argmax(-1).item()
        return labels[predicted_class_idx]
    except Exception as e:
        # UI boundary: log the error server-side and return a generic message
        # rather than exposing a stack trace to the end user.
        print("An error occurred:", e)
        print(traceback.format_exc())  # Print the full error traceback for debugging
        return "An error occurred while processing your request."
# Web UI: one NumPy image input is fed to `predict`, and the returned label is
# shown as plain text. Arguments are passed positionally as (fn, inputs, outputs).
interface = gr.Interface(
    predict,
    gr.Image(type="numpy"),
    "text",
    title="Bird Classification",
    description="Upload an image of a bird to classify it.",
)


def _launch():
    """Start the Gradio server; share=True also requests a public share link."""
    interface.launch(share=True)


if __name__ == "__main__":
    _launch()