Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
from PIL import Image | |
from torchvision import transforms | |
from model import ResNet50 # Assuming your model architecture is defined in a separate file called model.py | |
# Load the model once per server process. Streamlit re-executes this script
# from top to bottom on every widget interaction, so without caching the
# weights file would be re-read and the network rebuilt on each rerun.
@st.cache_resource
def _load_model():
    """Build the network, restore its saved weights on CPU, and switch to eval mode."""
    net = ResNet50()
    # map_location forces every tensor onto the CPU regardless of the device
    # the checkpoint was saved from.
    net.load_state_dict(torch.load('best_modelv2.pth', map_location=torch.device('cpu')))
    net.eval()
    return net


model = _load_model()
# Preprocessing pipeline applied to each input image before it reaches the
# model: resize to 224x224, convert to a tensor, then normalise per channel.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

data_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
def predict_image_label(image):
    """Return the predicted class index for a single image.

    Args:
        image: The image to classify, in any form accepted by
            ``data_transforms``. PIL images in a non-RGB mode are converted
            to RGB first.

    Returns:
        The index of the highest-scoring class as a plain Python ``int``.
    """
    # Uploaded PNGs are often RGBA, palette or grayscale. The normalisation
    # step is configured with three per-channel values, so any other channel
    # count would fail there; force three channels up front.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # Add a leading batch dimension of size one.
    batch = data_transforms(image).unsqueeze(0)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(batch)

    # Index of the largest score along dimension 1.
    return output.argmax(dim=1).item()
# Streamlit UI: upload an image, show it, then display the predicted label.
st.title("Leaf or Plant Classifier")

uploaded_file = st.file_uploader(
    "Choose an image...",
    type=["jpg", "jpeg", "png"],
)

if uploaded_file is not None:
    # Echo the upload back to the user before classifying it.
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Class index 0 is shown as "Leaf"; every other index is shown as "Plant".
    prediction = predict_image_label(image)
    if prediction == 0:
        label = "Leaf"
    else:
        label = "Plant"

    st.write(f"Prediction: {label}")