# Page residue from the original upload (kept for provenance, not code):
# Devon12's picture
# Update app.py
# 355cc40 verified
import gradio as gr
import torch
from torchvision import transforms, models
from torch import nn
from PIL import Image
# Build the network: a ResNet-50 backbone with no pretrained weights whose
# final fully connected layer is replaced by a ``num_classes``-way head.
model = models.resnet50(weights=None)
num_classes = 30  # NOTE(review): must equal len(class_labels) defined below
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)

# Path of the fine-tuned checkpoint (a state dict for the model above).
MODEL_PATH = "best_model.pth"

# Load the trained weights. Tensors are mapped onto the CPU first so the
# checkpoint loads on machines without CUDA; the model is moved to the chosen
# device below.
# NOTE(review): consider passing weights_only=True (torch>=1.13) so loading the
# checkpoint cannot execute arbitrary pickled code.
try:
    state_dict = torch.load(MODEL_PATH, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
except Exception as e:
    # Without these weights the network is randomly initialised and every
    # prediction would be meaningless, so fail loudly instead of serving it.
    raise RuntimeError(f"Error loading model from {MODEL_PATH!r}: {e}") from e
print("Model loaded successfully.")

# Run on the GPU when one is available, otherwise stay on the CPU, and put
# the network into evaluation mode for inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
# Preprocessing applied to every uploaded image before inference: resize to a
# fixed 224x224, convert to a float tensor, then normalise each channel.
# The mean/std values match the commonly used ImageNet channel statistics;
# presumably the model was trained with the same pipeline — confirm and keep
# these steps in sync with training.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)
# Human-readable name for each model output index: class_labels[i] is the
# label for logit i. The list is in alphabetical order — presumably matching
# the class indexing used at training time (e.g. a sorted folder layout);
# confirm before reordering or adding entries.
class_labels = [
    "aerosol_cans",
    "aluminum_food_cans",
    "aluminum_soda_cans",
    "cardboard_boxes",
    "cardboard_packaging",
    "clothing",
    "coffee_grounds",
    "disposable_plastic_cutlery",
    "eggshells",
    "food_waste",
    "glass_beverage_bottles",
    "glass_cosmetic_containers",
    "glass_food_jars",
    "magazines",
    "newspaper",
    "office_paper",
    "paper_cups",
    "plastic_cup_lids",
    "plastic_detergent_bottles",
    "plastic_food_containers",
    "plastic_shopping_bags",
    "plastic_soda_bottles",
    "plastic_straws",
    "plastic_trash_bags",
    "plastic_water_bottles",
    "shoes",
    "steel_food_cans",
    "styrofoam_cups",
    "styrofoam_food_containers",
    "tea_bags",
]
# Prediction function
def predict_image(image):
    """Classify one image and return the predicted class label.

    Args:
        image: A PIL image from the Gradio upload widget, or ``None`` when the
            widget is empty (e.g. after the user clears it in live mode).

    Returns:
        The entry of ``class_labels`` with the highest model score, or an
        empty string when no image is provided.
    """
    # With live=True Gradio re-runs this function whenever the input changes,
    # including when the image is cleared, in which case it passes None.
    if image is None:
        return ""
    # The 3-channel normalisation requires RGB; convert other modes
    # (grayscale, RGBA, palette, ...).
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Add a batch dimension of 1 and move the tensor to the model's device.
    input_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(input_tensor)
    predicted_index = outputs.argmax(dim=1).item()
    return class_labels[predicted_index]
# Gradio UI: a single image upload (delivered to predict_image as a PIL
# image) and the predicted label shown as text. live=True re-runs the
# prediction automatically whenever the input changes.
image_input = gr.Image(type="pil", label="Upload Image")
interface = gr.Interface(
    live=True,
    inputs=image_input,
    outputs="text",
    fn=predict_image,
)

# Start the Gradio web server.
interface.launch()