import gradio as gr
import torch
from torchvision import transforms, models
from torch import nn
from PIL import Image

# Rebuild the model architecture: ResNet-50 with a 30-class classification head
model = models.resnet50(weights=None)
num_classes = 30
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)

# Load the trained model weights
try:
    model.load_state_dict(torch.load("best_model.pth", map_location=torch.device('cpu')))
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")

# Move the model to the available device and switch to evaluation mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

# Define the image transformations (adjust as needed for your model)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Define class labels (must match the order used during training)
class_labels = [
    "aerosol_cans", "aluminum_food_cans", "aluminum_soda_cans", "cardboard_boxes",
    "cardboard_packaging", "clothing", "coffee_grounds", "disposable_plastic_cutlery",
    "eggshells", "food_waste", "glass_beverage_bottles", "glass_cosmetic_containers",
    "glass_food_jars", "magazines", "newspaper", "office_paper", "paper_cups",
    "plastic_cup_lids", "plastic_detergent_bottles", "plastic_food_containers",
    "plastic_shopping_bags", "plastic_soda_bottles", "plastic_straws",
    "plastic_trash_bags", "plastic_water_bottles", "shoes", "steel_food_cans",
    "styrofoam_cups", "styrofoam_food_containers", "tea_bags"
]

# Prediction function: preprocess the image, run inference, and return the class name
def predict_image(image):
    if image.mode != "RGB":
        image = image.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(input_tensor)
        _, predicted = torch.max(outputs, 1)
    label = class_labels[predicted.item()]
    return label

# Gradio interface setup
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs="text",
    live=True
)

# Launch the Gradio app
interface.launch()
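
# Optional sanity check, as an alternative to launching the web UI. This is a
# minimal sketch; "sample.jpg" is an assumed example path, not part of the
# original script. Run it in place of interface.launch() to get a single
# prediction on a local file:
#
#     img = Image.open("sample.jpg")
#     print(predict_image(img))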