riyadifirman's picture
Update app.py
d75157f verified
raw
history blame
1.21 kB
import gradio as gr
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from torchvision.transforms import Compose, Resize, ToTensor, Normalize,RandomHorizontalFlip, RandomRotation
from PIL import Image
# Load the image classifier and its matching image processor (which carries the
# normalisation statistics used below) from the checkpoint named here.
model_name = "riyadifirman/klasifikasiburung"
processor, model = (
    AutoImageProcessor.from_pretrained(model_name),
    AutoModelForImageClassification.from_pretrained(model_name),
)
# Inference-time preprocessing: PIL image -> normalised float tensor (C, H, W).
# Only deterministic steps belong here. Random augmentation such as horizontal
# flips or rotations is a training-time technique; applying it at inference
# would make repeated predictions on the same image disagree.
# Mean/std come from the checkpoint's own image processor.
# NOTE(review): 224x224 is assumed to match the checkpoint's processor.size — confirm.
normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    normalize,
])
def predict(image):
    """Classify a bird image and return the predicted label.

    Args:
        image: Image from the Gradio ``Image`` component configured with
            ``type="numpy"`` (a NumPy array), or ``None`` when the user
            submits without uploading anything.

    Returns:
        The class label looked up in ``model.config.id2label`` (falling back
        to the stringified class index if the index is missing), or a short
        prompt asking for an image when ``image`` is ``None``.
    """
    if image is None:
        # Gradio passes None for an empty upload; Image.fromarray would crash.
        return "Please upload an image."
    # Force RGB so the channel count matches the per-channel mean/std used by
    # ``normalize`` (presumably 3 RGB values — confirm against processor.image_mean).
    # Grayscale or RGBA arrays would otherwise break normalisation.
    pil_image = Image.fromarray(image).convert("RGB")
    # Add a leading batch dimension of size 1.
    pixel_values = transform(pil_image).unsqueeze(0)
    # Pure inference: no gradients needed.
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
    predicted_class_idx = logits.argmax(-1).item()
    # Image processors have no ``decode`` method; class names live on the model config.
    return model.config.id2label.get(predicted_class_idx, str(predicted_class_idx))
# Build the Gradio web UI: one image upload in, the predicted label out as text.
# ``gr.Image`` replaces the legacy ``gr.inputs.Image`` namespace, which was
# deprecated in Gradio 3.x and removed in Gradio 4.0.
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs="text",
    title="Bird Classification",
    description="Upload an image of a bird to classify it.",
)

if __name__ == "__main__":
    # Start the server only when executed as a script, not on import.
    interface.launch()