Devon12 commited on
Commit
4025a0b
1 Parent(s): 1fc529a

Add Gradio app, requirements, and model

Browse files
Files changed (2) hide show
  1. app.py +65 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torchvision import transforms, models
4
+ from torch import nn
5
+ from PIL import Image
6
+
7
+ # Load the model architecture
8
+ model = models.resnet50(weights=None)
9
+ num_classes = 30
10
+ num_features = model.fc.in_features
11
+ model.fc = nn.Linear(num_features, num_classes)
12
+
13
+ # Load the trained model weights
14
+ try:
15
+ model.load_state_dict(torch.load("best.pt", map_location=torch.device('cpu')))
16
+ print("Model loaded successfully.")
17
+ except Exception as e:
18
+ print(f"Error loading model: {e}")
19
+
20
+ # Load your trained model
21
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ model = model.to(device)
23
+ model.eval()
24
+
25
+ # Define the image transformations (adjust as needed for your model)
26
+ transform = transforms.Compose([
27
+ transforms.Resize((224, 224)),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
30
+ ])
31
+
32
+ # Define class labels
33
+ class_labels = [
34
+ "aerosol_cans", "aluminum_food_cans", "aluminum_soda_cans", "cardboard_boxes",
35
+ "cardboard_packaging", "clothing", "coffee_grounds", "disposable_plastic_cutlery",
36
+ "eggshells", "food_waste", "glass_beverage_bottles", "glass_cosmetic_containers",
37
+ "glass_food_jars", "magazines", "newspaper", "office_paper", "paper_cups",
38
+ "plastic_cup_lids", "plastic_detergent_bottles", "plastic_food_containers",
39
+ "plastic_shopping_bags", "plastic_soda_bottles", "plastic_straws", "plastic_trash_bags",
40
+ "plastic_water_bottles", "shoes", "steel_food_cans", "styrofoam_cups",
41
+ "styrofoam_food_containers", "tea_bags"
42
+ ]
43
+
44
+ # Prediction function
45
+ def predict_image(image):
46
+ if image.mode != "RGB":
47
+ image = image.convert("RGB")
48
+ input_tensor = transform(image).unsqueeze(0).to(device)
49
+ with torch.no_grad():
50
+ outputs = model(input_tensor)
51
+ _, predicted = torch.max(outputs, 1)
52
+ label = class_labels[predicted.item()]
53
+ return label
54
+
55
+ # Gradio interface setup
56
+ interface = gr.Interface(
57
+ fn=predict_image,
58
+ inputs=gr.Image(type="pil", label="Upload Image or Use Webcam"),
59
+ outputs="text",
60
+ live=True
61
+ )
62
+
63
+ # Launch Gradio app
64
+ interface.launch()
65
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio==5.6.0
2
+ torch
3
+ torchvision