File size: 1,679 Bytes
0d599fc
 
 
b5e285e
0d599fc
b5e285e
0d599fc
20115ca
 
 
 
0d599fc
 
 
 
 
 
 
 
 
d75157f
 
0d599fc
 
 
 
 
b5e285e
 
 
 
 
 
20115ca
 
b5e285e
 
 
 
 
0d599fc
 
 
 
b5e285e
0d599fc
 
 
 
 
 
20115ca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import traceback

import gradio as gr
import torch
from datasets import load_dataset
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomHorizontalFlip, RandomRotation
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Load dataset to get labels.
# Only the class-name list of the `label` feature is kept; `predict` indexes it
# with the model's predicted class index (assumes the model was trained with
# this same label order -- TODO confirm).
dataset = load_dataset("bentrevett/caltech-ucsd-birds-200-2011")
labels = dataset['train'].features['label'].names

# Load model and processor
model_name = "riyadifirman/klasifikasiburung"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# Define image transformations for inference.
# Only deterministic steps belong here: random flips/rotations are training-time
# augmentations and would make the same upload produce different predictions
# from one request to the next.
normalize = Normalize(mean=processor.image_mean, std=processor.image_std)
transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    normalize,
])

def predict(image):
    """Classify an uploaded bird image and return the predicted class name.

    Args:
        image: The upload as delivered by the ``gr.Image(type="numpy")`` input
            component, or ``None`` when the form is submitted without an image.

    Returns:
        The predicted label taken from the module-level ``labels`` list, a
        short prompt when no image was supplied, or a generic error message
        when anything in the pipeline fails.
    """
    # Submitting without an image is a user mistake, not a pipeline failure:
    # answer with a clear prompt instead of logging a traceback.
    if image is None:
        return "Please upload an image first."

    try:
        # Force three channels so grayscale or RGBA arrays still match the
        # per-channel normalisation (assumes the processor defines RGB
        # statistics -- TODO confirm).
        pil_image = Image.fromarray(image).convert("RGB")
        # The model expects a batch dimension, hence unsqueeze(0).
        inputs = transform(pil_image).unsqueeze(0)
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = model(inputs)
        logits = outputs.logits
        predicted_class_idx = logits.argmax(-1).item()
        return labels[predicted_class_idx]
    except Exception as e:
        # Log the error; the UI only gets a generic message.
        print("An error occurred:", e)
        print(traceback.format_exc())  # Prints the full error details
        return "An error occurred while processing your request."

# Build the web UI: a single image upload goes in (handed to `predict` as a
# NumPy array) and the returned class name comes out as plain text.
interface = gr.Interface(
    predict,
    gr.Image(type="numpy"),
    "text",
    description="Upload an image of a bird to classify it.",
    title="Bird Classification",
)

# Start the server only when this file is run as a script; share=True asks
# Gradio for a public share link in addition to the local address.
if __name__ == "__main__":
    interface.launch(share=True)