"""Streamlit app: classify an uploaded image with a pre-trained ResNet-50.

The script loads a torchvision ResNet-50, downloads a JSON file of ImageNet
class labels, lets the user upload a JPEG/PNG image, and writes the label of
the highest-scoring class to the page.
"""

from io import BytesIO

import requests
import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models import resnet50

# Source of the ImageNet class labels (JSON, indexed by class id below).
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"

# Upper bound on the label download so a stalled connection cannot hang the app.
LABELS_TIMEOUT_SECONDS = 10

# Streamlit re-executes this script on every user interaction. Caching the
# heavy resources avoids rebuilding the model and re-downloading the labels
# on each rerun. `st.cache_resource` is not present in older Streamlit
# releases, so fall back to a no-op decorator there.
_cache_resource = getattr(st, "cache_resource", None) or (lambda func: func)


@_cache_resource
def _load_model():
    """Return the pre-trained ResNet-50 switched to evaluation mode."""
    # NOTE(review): newer torchvision releases deprecate `pretrained=True` in
    # favour of the `weights=` argument; kept as-is so the installed
    # torchvision version does not need to change.
    net = resnet50(pretrained=True)
    net.eval()
    return net


@_cache_resource
def _load_labels():
    """Download and return the parsed label JSON from ``LABELS_URL``.

    Raises:
        requests.RequestException: on timeout, connection failure, or a
            non-success HTTP status.
    """
    response = requests.get(LABELS_URL, timeout=LABELS_TIMEOUT_SECONDS)
    # Fail with a clear HTTP error instead of an obscure JSON decode error.
    response.raise_for_status()
    return response.json()


# Load the pre-trained ResNet-50 model
model = _load_model()

# Define the image transforms
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Define the label map for ImageNet classes
labels = _load_labels()

# Streamlit UI
st.title("Image Classification with Pre-trained ResNet-50")
st.write("Upload an image and the model will predict the class of the object in the image.")

# File uploader
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    # Open the image file
    image = Image.open(uploaded_file)

    # Display the image exactly as uploaded
    st.image(image, caption='Uploaded Image', use_column_width=True)
    st.write("")
    st.write("Classifying...")

    # Preprocess the image. Uploaded PNGs may be RGBA, palette, or grayscale,
    # while the normalisation above is defined for three channels, so force
    # RGB first. unsqueeze(0) adds the leading batch dimension.
    batch = transform(image.convert("RGB")).unsqueeze(0)

    # Predict the class; no gradients are needed for inference
    with torch.no_grad():
        outputs = model(batch)

    # Get the predicted class: index of the highest score along dim 1
    predicted = outputs.argmax(dim=1)
    predicted_class = labels[predicted.item()]

    # Display the result
    st.write(f"Predicted Class: {predicted_class}")