import torch
from torchvision import transforms
from torchvision import models
from PIL import Image
import gradio as gr
import os

# Use CPU
device = torch.device('cpu')

# Define the class names (Oxford-IIIT Pet breeds). Index i is the label for the
# model's output logit i (see classify_image), so the order must not change.
class_names = ['Abyssinian', 'American Bulldog', 'American Pit Bull Terrier', 'Basset Hound', 'Beagle',
               'Bengal', 'Birman', 'Bombay', 'Boxer', 'British Shorthair', 'Chihuahua', 'Egyptian Mau',
               'English Cocker Spaniel', 'English Setter', 'German Shorthaired', 'Great Pyrenees',
               'Havanese', 'Japanese Chin', 'Keeshond', 'Leonberger', 'Maine Coon', 'Miniature Pinscher',
               'Newfoundland', 'Persian', 'Pomeranian', 'Pug', 'Ragdoll', 'Russian Blue', 'Saint Bernard',
               'Samoyed', 'Scottish Terrier', 'Shiba Inu', 'Siamese', 'Sphynx',
               'Staffordshire Bull Terrier', 'Wheaten Terrier', 'Yorkshire Terrier']


def _load_model(weights_path, num_classes):
    """Load the fine-tuned ResNet-50 from *weights_path*, ready for inference.

    The checkpoint may contain either a whole pickled ``torch.nn.Module`` or a
    plain ``state_dict``. In the state_dict case the ResNet-50 architecture is
    rebuilt with a ``num_classes``-way classification head and the weights
    are loaded into it.

    Args:
        weights_path: Path to the ``.pth`` checkpoint.
        num_classes: Number of output classes for a rebuilt architecture.

    Returns:
        The model moved to ``device`` and switched to eval mode.
    """
    # weights_only=False is needed to unpickle a full nn.Module now that
    # PyTorch 2.6 defaults to weights_only=True (the kwarg exists since 1.13).
    # It executes arbitrary pickle code, so only point this at a trusted
    # checkpoint -- here a local file path, never user-supplied data.
    checkpoint = torch.load(weights_path, map_location=device, weights_only=False)
    if isinstance(checkpoint, torch.nn.Module):
        net = checkpoint
    else:
        # resnet50(num_classes=...) builds randomly initialised weights (no
        # download); load_state_dict then overwrites them with the trained ones.
        net = models.resnet50(num_classes=num_classes)
        net.load_state_dict(checkpoint)
    net.to(device)
    net.eval()
    return net


# Load the trained ResNet-50 onto the CPU in inference mode
model = _load_model('resnet50_model_weights.pth', num_classes=len(class_names))

# Define the image preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def classify_image(image):
    """Predict the pet breed shown in *image*.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        The predicted breed name taken from ``class_names``.

    Raises:
        gr.Error: If no image was provided; Gradio shows the message in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image of a cat or dog.")
    # Uploaded PNGs can be RGBA/palette and some photos are grayscale, but
    # Normalize above expects exactly three channels.
    batch = transform(image.convert('RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    predicted = logits.argmax(dim=1)
    return class_names[predicted.item()]


# Custom Gradio interface title and description
title = 'Oxford Pet 🐈🐕'
description = 'A ResNet50-based computer vision model for classifying images of pets from the Oxford-IIIT Pet Dataset. The model can recognize 37 different pet breeds, including cats and dogs.'
article = 'https://github.com/Eric-Chung-0511/Learning-Record/tree/main/Data%20Science%20Projects/The%20Oxford-IIIT%20Pet%20Project'

# File extensions accepted as example images. Anything else in the folder
# (hidden files such as .DS_Store, READMEs, ...) would break the examples table.
_EXAMPLE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp')


def _collect_examples(folder='examples'):
    """Return Gradio example rows for every image file in *folder*.

    Each row is a one-element list because the interface has a single input.
    File names are sorted so the example order is stable (``os.listdir``
    order is arbitrary). Returns ``None``, which ``gr.Interface`` treats as
    "no examples", when the folder is missing or contains no images.
    """
    if not os.path.isdir(folder):
        return None
    rows = [
        [os.path.join(folder, name)]
        for name in sorted(os.listdir(folder))
        if name.lower().endswith(_EXAMPLE_EXTENSIONS)
    ]
    return rows or None


# Gradio interface
examples = _collect_examples()
demo = gr.Interface(
    fn=classify_image,  # Map input to output function
    inputs=gr.Image(type="pil"),  # Image input
    outputs=[gr.Label(num_top_classes=1, label="Predictions")],  # Predicted label
    examples=examples,  # Example images (None when none are available)
    title=title,
    description=description,
    article=article,
)

# Launch the demo
demo.launch()