File size: 2,135 Bytes
4025a0b
 
 
 
 
 
 
 
 
 
 
 
 
 
355cc40
4025a0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17224be
4025a0b
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import gradio as gr
import torch
from torchvision import transforms, models
from torch import nn
from PIL import Image

# Build the ResNet-50 architecture without pretrained weights; the trained
# parameters are loaded from the checkpoint below.
model = models.resnet50(weights=None)
# Must equal the number of entries in ``class_labels``.
num_classes = 30
num_features = model.fc.in_features
# Replace the final fully-connected layer so it emits one logit per class.
model.fc = nn.Linear(num_features, num_classes)

# Load the trained model weights. Failing loudly is deliberate: if loading
# fails, the replaced ``fc`` layer keeps its random initialisation and every
# prediction the app served would be meaningless.
MODEL_PATH = "best_model.pth"
try:
    # NOTE(review): torch.load unpickles the checkpoint; only load files you
    # trust. On torch >= 1.13 consider passing weights_only=True.
    state_dict = torch.load(MODEL_PATH, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
except Exception as e:
    raise RuntimeError(f"Error loading model: {e}") from e
print("Model loaded successfully.")

# Move the model to the GPU when one is available, and switch it to
# evaluation mode so batch-norm layers use their running statistics.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

# Per-channel normalisation constants (the standard ImageNet mean/std values).
# NOTE(review): presumably these match the normalisation used at training
# time — confirm against the training script.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every uploaded image, in order: resize to a fixed
# 224x224 (aspect ratio is not preserved), convert to a float tensor, then
# normalise each channel. Adjust as needed for your model.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = transforms.Compose(_preprocessing_steps)

# Human-readable class names; position i is the label for model output i.
# The list is in alphabetical order — presumably matching the class-index
# assignment used at training time (e.g. torchvision ImageFolder); confirm.
class_labels = [
    "aerosol_cans",
    "aluminum_food_cans",
    "aluminum_soda_cans",
    "cardboard_boxes",
    "cardboard_packaging",
    "clothing",
    "coffee_grounds",
    "disposable_plastic_cutlery",
    "eggshells",
    "food_waste",
    "glass_beverage_bottles",
    "glass_cosmetic_containers",
    "glass_food_jars",
    "magazines",
    "newspaper",
    "office_paper",
    "paper_cups",
    "plastic_cup_lids",
    "plastic_detergent_bottles",
    "plastic_food_containers",
    "plastic_shopping_bags",
    "plastic_soda_bottles",
    "plastic_straws",
    "plastic_trash_bags",
    "plastic_water_bottles",
    "shoes",
    "steel_food_cans",
    "styrofoam_cups",
    "styrofoam_food_containers",
    "tea_bags",
]

# Prediction function
def predict_image(image):
    """Classify a single uploaded image and return its predicted label.

    Args:
        image: A PIL image supplied by the Gradio ``Image`` component, or
            ``None`` when the input is empty (e.g. the user cleared it).

    Returns:
        The predicted class name from ``class_labels``, or an empty string
        when no image was provided.
    """
    # The interface runs live, so this is also called when the image input
    # is cleared; Gradio then passes None instead of an image.
    if image is None:
        return ""
    # The network takes 3-channel input, so grayscale/RGBA/palette images are
    # converted first.
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Add a batch dimension of 1 and move the tensor to the model's device.
    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    # Index of the highest-scoring class for the single image in the batch.
    predicted_index = logits.argmax(dim=1).item()
    return class_labels[predicted_index]

# Gradio interface setup: a single PIL image upload mapped to a text label.
# live=True re-runs predict_image whenever the input changes.
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs="text",
    live=True,
)

# Launch the Gradio app only when run as a script (e.g. `python app.py`), so
# that importing this module does not start a blocking web server.
if __name__ == "__main__":
    interface.launch()