# X-RayDemo / app.py
# Last commit 7755c1c by tousin23: "Update app.py" (1.61 kB).
from io import BytesIO

import requests
import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models import resnet50
from torchvision.models import ResNet50_Weights
# Load the pre-trained ResNet-50 model once per server process. Streamlit
# re-executes this script top-to-bottom on every widget interaction, so
# without caching the network would be rebuilt on each rerun.
@st.cache_resource
def _load_model():
    """Return an ImageNet-pretrained ResNet-50 switched to eval mode."""
    # IMAGENET1K_V1 is the checkpoint the legacy ``pretrained=True`` flag
    # selected; the ``pretrained=`` keyword is deprecated in torchvision.
    net = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
    # eval() puts layers such as BatchNorm into inference behaviour.
    net.eval()
    return net


model = _load_model()
# Preprocessing pipeline applied to every uploaded image before inference:
# resize so the shorter side is 256 px, take the central 224x224 crop,
# convert the PIL image to a tensor, then normalise each channel with the
# ImageNet statistics the pretrained weights expect.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Human-readable names for the ImageNet classes, indexed by the model's
# output class id.
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"


@st.cache_data
def _load_labels(url: str = LABELS_URL, timeout: float = 10.0) -> list:
    """Download and return the JSON list of class names found at *url*.

    Cached so the file is fetched once rather than on every Streamlit rerun.

    Raises:
        requests.Timeout: if the server does not respond within *timeout* s.
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    resp = requests.get(url, timeout=timeout)
    # Fail loudly on an error status instead of JSON-decoding an error page.
    resp.raise_for_status()
    return resp.json()


labels = _load_labels()
# ---- Page header -----------------------------------------------------------
st.title("Image Classification with Pre-trained ResNet-50")
st.write("Upload an image and the model will predict the class of the object in the image.")

# ---- Upload widget ---------------------------------------------------------
# Yields None until the user has uploaded a file; only these extensions
# are offered in the file picker.
uploaded_file = st.file_uploader(
    "Choose an image...",
    type=["jpg", "jpeg", "png"],
)
if uploaded_file is not None:
    # Decode the upload and force 3-channel RGB: PNGs may be RGBA or
    # palette-based and some images are single-channel greyscale, but the
    # Normalize step in `transform` has exactly three channel means/stds.
    image = Image.open(uploaded_file).convert("RGB")

    # Show the (RGB-converted) image back to the user.
    # NOTE(review): `use_column_width` is deprecated in recent Streamlit
    # releases — confirm against the pinned Streamlit version.
    st.image(image, caption='Uploaded Image', use_column_width=True)
    st.write("")
    st.write("Classifying...")

    # Preprocess and add a leading batch dimension -> (1, 3, 224, 224).
    batch = transform(image).unsqueeze(0)

    # Inference only: no_grad skips building the autograd graph.
    with torch.no_grad():
        outputs = model(batch)

    # The predicted class is the index of the highest score in each row.
    predicted = outputs.argmax(dim=1)
    predicted_class = labels[predicted.item()]

    st.write(f"Predicted Class: {predicted_class}")