# Uploaded by DawnC in commit 37f6bf3 ("Upload 5 files", verified); original file size 2.52 kB.
import torch
from torchvision import transforms
from torchvision import models
from PIL import Image
import gradio as gr
import os
# Run all inference on the CPU.
device = torch.device('cpu')
# Load the saved checkpoint onto the CPU. weights_only=False is needed because
# the file may hold a whole pickled nn.Module, which PyTorch >= 2.6 refuses to
# load under its new default of weights_only=True.
# NOTE(review): torch.load unpickles arbitrary objects -- only ever point this
# at the bundled, trusted checkpoint, never at a user-supplied file.
checkpoint = torch.load(
    'resnet50_model_weights.pth', map_location=device, weights_only=False
)
if isinstance(checkpoint, torch.nn.Module):
    # The file stores the complete model: use it directly instead of building
    # (and then discarding) a fresh ResNet-50 architecture.
    model = checkpoint
else:
    # The file stores a state_dict: rebuild ResNet-50 with one output per
    # breed and load the weights into it.
    # NOTE(review): assumes the checkpoint's classifier head is the stock
    # torchvision `fc` resized to 37 outputs -- confirm against training code.
    model = models.resnet50(num_classes=37)
    model.load_state_dict(checkpoint)
# Switch off training-only behavior (dropout, batch-norm statistic updates).
model.eval()
# Per-channel RGB normalization statistics; these are the ImageNet mean/std
# values documented for torchvision's pretrained classification models.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every incoming PIL image before it reaches the
# model: squash to a fixed 224x224 size (aspect ratio is not preserved),
# convert to a float tensor, then normalize each channel.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)
# Breed names, indexed by model output position: the predicted index i is
# mapped to class_names[i], so this (alphabetical) order must match the label
# order the model was trained with. Grouped one initial letter per line.
class_names = [
    'Abyssinian', 'American Bulldog', 'American Pit Bull Terrier',
    'Basset Hound', 'Beagle', 'Bengal', 'Birman', 'Bombay', 'Boxer',
    'British Shorthair',
    'Chihuahua',
    'Egyptian Mau', 'English Cocker Spaniel', 'English Setter',
    'German Shorthaired', 'Great Pyrenees',
    'Havanese',
    'Japanese Chin',
    'Keeshond',
    'Leonberger',
    'Maine Coon', 'Miniature Pinscher',
    'Newfoundland',
    'Persian', 'Pomeranian', 'Pug',
    'Ragdoll', 'Russian Blue',
    'Saint Bernard', 'Samoyed', 'Scottish Terrier', 'Shiba Inu', 'Siamese',
    'Sphynx', 'Staffordshire Bull Terrier',
    'Wheaten Terrier',
    'Yorkshire Terrier',
]
# Define the predict function
def classify_image(image):
    """Predict the pet breed shown in an image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None when
            no image was provided.

    Returns:
        The entry of ``class_names`` at the index of the model's highest
        output, or None when ``image`` is None.
    """
    if image is None:
        # Nothing to classify (e.g. the image input was left empty).
        return None
    # Force 3-channel RGB. RGBA/palette PNGs and grayscale files would
    # otherwise give ToTensor a 4- or 1-channel tensor, which the 3-channel
    # Normalize step in `transform` rejects with a shape error.
    image = image.convert('RGB')
    # Preprocess, add a leading batch dimension, and move to the CPU device.
    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(batch)
    # Index of the highest-scoring output along the class dimension.
    predicted = outputs.argmax(dim=1)
    return class_names[predicted.item()]
# Text shown by the Gradio interface: the page title, a description under it,
# and an article footer linking to the project's GitHub folder.
# The title previously contained mis-encoded bytes ('πŸˆπŸ•', i.e. UTF-8
# emoji decoded as cp1253); restored to the intended cat and dog emoji.
title = 'Oxford Pet 🐈🐕'
description = 'A ResNet50-based computer vision model for classifying images of pets from the Oxford-IIIT Pet Dataset. The model can recognize 37 different pet breeds, including cats and dogs.'
article = 'https://github.com/Eric-Chung-0511/Learning-Record/tree/main/Data%20Science%20Projects/The%20Oxford-IIIT%20Pet%20Project'
# Example images offered under the Gradio interface: every visible regular
# file in the local 'examples' directory, listed in sorted order so the
# gallery is stable across runs (os.listdir order is arbitrary).
examples = [
    [os.path.join('examples', name)]
    for name in sorted(os.listdir('examples'))
    # Skip hidden files (e.g. .DS_Store) and sub-directories, which are not
    # example images.
    if not name.startswith('.') and os.path.isfile(os.path.join('examples', name))
]
# Gradio interface: one PIL image in, the predicted breed label out, with the
# example gallery and the title/description/article text attached.
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Label(num_top_classes=1, label="Predictions")],
    examples=examples,
    title=title,
    description=description,
    article=article,
)
# Start the web server only when run as a script (`python app.py`), so that
# importing this module (reload tooling, tests) does not launch it.
if __name__ == "__main__":
    demo.launch()