|
import streamlit as st |
|
import torch |
|
import torchvision.transforms as transforms |
|
from torchvision.models import resnet50 |
|
from PIL import Image |
|
import requests |
|
from io import BytesIO |
|
|
|
|
|
@st.cache_resource
def _load_model():
    """Build a ResNet-50 with ImageNet-1k weights and put it in eval mode.

    Streamlit re-executes this script from the top on every widget
    interaction. Caching the model as a shared resource means it is built
    (and its weights loaded) once per server process instead of on every
    rerun.
    """
    try:
        # Weight enums exist from torchvision 0.13 onward; there the legacy
        # `pretrained=True` flag is deprecated and maps to IMAGENET1K_V1.
        from torchvision.models import ResNet50_Weights
    except ImportError:  # older torchvision only understands the legacy flag
        net = resnet50(pretrained=True)
    else:
        net = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    # eval() makes BatchNorm use its running statistics for inference.
    net.eval()
    return net


model = _load_model()
|
|
|
|
|
# Evaluation-time preprocessing applied to each uploaded PIL image:
# resize (int size -> shorter side becomes 256 px), take the central
# 224x224 crop, convert to a CHW float tensor, then normalise each channel
# with the standard ImageNet per-channel mean / std.
transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"


@st.cache_data
def _load_labels(url, timeout=10):
    """Download *url* and return its parsed JSON body.

    Cached so the file is fetched once rather than on every Streamlit rerun.
    Failed calls are not cached, so a transient network error is retried on
    the next rerun.

    Raises:
        requests.Timeout: if the server does not respond within *timeout* seconds.
        requests.HTTPError: on a non-2xx response. Without this check a 404 page
            would surface later as a confusing JSON decode error.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.json()


# Presumably a list of class names in the same order as the model's output
# logits; the classification step indexes it with the argmax index.
labels = _load_labels(LABELS_URL)
|
|
|
|
|
# Page header: title plus a one-line usage hint shown above the uploader.
st.title("Image Classification with Pre-trained ResNet-50")
st.write(
    "Upload an image and the model will predict the class of the object in the image."
)
|
|
|
|
|
# Main flow: wait for an upload, display it, then classify it with the model.
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # convert("RGB") forces PIL's lazy decode, so corrupt files fail here.
        # It also maps RGBA / palette PNGs and greyscale images to the three
        # channels that the Normalize step and the model's first conv expect.
        # Without it, a transparent PNG crashes the pipeline.
        rgb_image = image.convert("RGB")
    except OSError:  # PIL.UnidentifiedImageError and truncated files are OSErrors
        st.error("Could not read the uploaded file as an image.")
        st.stop()

    # Show the upload as-is; only the classifier input is converted.
    st.image(image, caption='Uploaded Image', use_column_width=True)
    st.write("")
    st.write("Classifying...")

    # unsqueeze(0) adds the leading batch dimension the model expects.
    batch = transform(rgb_image).unsqueeze(0)

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        outputs = model(batch)

    # Index of the highest-scoring logit along the class dimension.
    _, predicted = torch.max(outputs, 1)
    predicted_class = labels[predicted.item()]

    st.write(f"Predicted Class: {predicted_class}")
|
|