# app.py — Gradio food image classification demo.
# Captured from a web file view: author KabeerAmjad, commit 53b38ec ("Update app.py"), 1.97 kB.
import gradio as gr
import torch
from transformers import AutoFeatureExtractor
from torchvision import models, transforms
from PIL import Image
# Load your trained model from Hugging Face (if available) or load locally
model_id = "KabeerAmjad/food_classification_model"  # Replace with your actual model ID
model = models.resnet50()  # Load ResNet50 architecture (randomly initialised; weights come next)
# map_location="cpu" remaps tensors that were saved on a GPU so the checkpoint
# also loads on a CPU-only host; without it torch.load fails when CUDA is not
# available. classify_image runs the model on CPU tensors anyway.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints;
# consider weights_only=True (torch >= 1.13) to restrict loading to tensors.
# NOTE(review): torchvision's resnet50() builds a 1000-class fc head by default,
# while classify_image only knows 4 labels. If the checkpoint was fine-tuned with
# a smaller head, model.fc must be replaced before load_state_dict (otherwise a
# size-mismatch error is raised) — confirm against the training code.
model.load_state_dict(torch.load("path_to_trained_model_weights.pth", map_location="cpu"))  # Load the trained weights
model.eval()  # Set to evaluation mode
# Load the feature extractor (can be used if any custom preprocessing was applied)
# NOTE(review): feature_extractor is not used by classify_image (which does its
# own torchvision preprocessing); it only adds a download at startup. Consider removing.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
# Define the prediction function
def classify_image(img):
    """Classify a food photo and return the predicted label.

    Args:
        img: A PIL image, as delivered by ``gr.Image(type="pil")``, or ``None``
            when the user submits the form without uploading an image.

    Returns:
        The predicted class label as a string, or a short prompt asking for an
        image when ``img`` is ``None``.
    """
    # Gradio hands the callback None when no image was uploaded; the transforms
    # below would raise on it, so answer with a prompt instead.
    if img is None:
        return "Please upload an image."

    # Deterministic inference-time preprocessing. Random flips, rotations and
    # colour jitter are training-time augmentations: applied here they made the
    # same image produce different predictions on different calls.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Conventional ImageNet channel mean/std used with torchvision ResNets.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Force exactly 3 channels so RGBA / greyscale / palette images match the
    # 3-value mean/std above.
    img_tensor = preprocess(img.convert("RGB")).unsqueeze(0)  # Add batch dimension

    # Make prediction with the model (no autograd bookkeeping needed).
    with torch.no_grad():
        logits = model(img_tensor)
    # softmax is monotonic, so the argmax of the raw logits is the same class.
    predicted_index = logits.argmax(dim=1).item()

    class_labels = ["Apple Pie", "Burger", "Pizza", "Tacos"]  # Replace with your actual class labels
    # NOTE(review): if the model's output head has more classes than
    # class_labels, this indexing raises IndexError — keep the two in sync.
    return class_labels[predicted_index]
# Build the web UI: one uploaded image in (as a PIL image), one text label out.
iface = gr.Interface(
    classify_image,
    gr.Image(type="pil"),
    "text",
    description="Upload an image to classify if it’s an apple pie, burger, pizza, etc.",
    title="Food Image Classification",
)
# Start serving the app.
iface.launch()