Spaces:
Running
Running
import gradio as gr | |
import torch | |
from PIL import Image | |
import torchvision.transforms as transforms | |
import numpy as np | |
import torch.nn.functional as F | |
from safetensors.torch import load_model, save_model | |
from models import * | |
import os | |
class WasteClassifier:
    """Run single-image inference with a model that classifies and segments.

    The wrapped model is called with one batched image tensor and must return
    a pair ``(class_logits, seg_mask)``; see :meth:`predict` for how each
    output is post-processed.
    """

    def __init__(self, model, class_names, device):
        """Store the model and build the preprocessing pipeline.

        Args:
            model: Callable model returning ``(class_logits, seg_mask)``.
            class_names: Sequence of labels; position ``i`` names the class
                whose score is at index ``i`` of the classification output.
            device: Device the input tensor is moved to before the forward
                pass (the model itself is not moved here).
        """
        self.model = model
        self.class_names = class_names
        self.device = device
        # Fixed 384x384 input with three-channel normalisation statistics
        # (the values match the usual ImageNet mean/std).
        self.transform = transforms.Compose(
            [
                transforms.Resize((384, 384)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def predict(self, image):
        """Classify ``image`` and return its segmentation mask.

        Args:
            image: A ``PIL.Image.Image`` or an array accepted by
                ``PIL.Image.fromarray``.

        Returns:
            A dict with the keys:

            * ``"predicted_class"``: label with the highest probability.
            * ``"confidence"``: that probability as a Python ``float``.
            * ``"class_probabilities"``: ``{label: float probability}`` for
              every entry of ``class_names``.
            * ``"segmentation_mask"``: 2-D ``float32`` array holding the raw
              mask values, resized (nearest neighbour) to the input image's
              size.

        Raises:
            ValueError: If ``image`` is ``None``.
        """
        if image is None:
            raise ValueError("image must not be None")

        self.model.eval()

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # The normalisation step is defined for exactly three channels, so
        # RGBA, grayscale or palette images must be converted first.
        if image.mode != "RGB":
            image = image.convert("RGB")

        original_size = image.size  # PIL order: (width, height)
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # The model has two heads: class scores and a segmentation mask.
            outputs, seg_mask = self.model(img_tensor)
            probabilities = torch.softmax(outputs, dim=1)

        probs = probabilities[0].cpu().numpy()
        best_idx = int(np.argmax(probs))

        # First image of the batch, first channel
        # (assumes the mask is laid out as (batch, channel, H, W) - TODO confirm).
        mask_arr = seg_mask[0, 0].cpu().numpy().astype(np.float32)

        # Bring the mask back to the input resolution without interpolating
        # values, so any later thresholding sees the model's own numbers.
        mask_img = Image.fromarray(mask_arr).resize(original_size, Image.NEAREST)

        return {
            "predicted_class": self.class_names[best_idx],
            "confidence": float(probs[best_idx]),
            "class_probabilities": {
                class_name: float(prob)
                for class_name, prob in zip(self.class_names, probs)
            },
            "segmentation_mask": np.array(mask_img),
        }
def interface(classifier, mask_threshold=0.2):
    """Build the Gradio demo around ``classifier``.

    Args:
        classifier: Object whose ``predict(image)`` returns a dict with the
            keys ``"predicted_class"``, ``"confidence"``,
            ``"class_probabilities"`` and ``"segmentation_mask"``.
        mask_threshold: Pixels whose mask value is below this are blacked out
            in the "Segmented Object" output. Defaults to ``0.2``.

    Returns:
        The configured ``gr.Interface`` (not yet launched).
    """

    def process_image(image):
        """Run the classifier and format its result for the three outputs."""
        if image is None:
            # Gradio passes None when the user submits without an upload.
            raise gr.Error("Please upload an image before submitting.")

        results = classifier.predict(image)

        image_np = np.array(image) if isinstance(image, Image.Image) else image
        mask = results["segmentation_mask"]

        # Keep only the pixels the mask marks as object; zero the rest.
        overlay = image_np.copy()
        overlay[mask < mask_threshold] = 0

        ranked = sorted(
            results["class_probabilities"].items(),
            key=lambda item: item[1],
            reverse=True,
        )
        lines = [
            f"Predicted Class: {results['predicted_class']}",
            f"Confidence: {results['confidence']*100:.2f}%",
            "",
            "Class Probabilities:",
        ]
        lines.extend(f"{class_name}: {prob*100:.2f}%" for class_name, prob in ranked)
        output_str = "\n".join(lines) + "\n"

        # Clip before scaling so values outside [0, 1] saturate instead of
        # wrapping around when cast to uint8.
        mask_viz = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)

        return [output_str, overlay, mask_viz]

    # Offer only the example files that are actually present; listing a
    # missing file would make the interface fail while it is being built.
    example_rows = [
        [path]
        for path in ("example1.jpg", "example2.jpg", "example3.jpg")
        if os.path.exists(path)
    ]

    demo = gr.Interface(
        fn=process_image,
        inputs=[gr.Image(type="pil", label="Upload Image")],
        outputs=[
            gr.Textbox(label="Classification Results"),
            gr.Image(label="Segmented Object"),
            gr.Image(label="Segmentation Mask"),
        ],
        title="Waste Classification System",
        description=(
            "\n"
            "Upload an image of waste to classify it into different categories.\n"
            "The model will predict the type of waste, show confidence scores for each category,\n"
            "and display the segmented object along with its mask.\n"
        ),
        examples=example_rows or None,
        analytics_enabled=False,
        theme="default",
    )
    return demo
# Use the GPU when torch can see one, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The classifier maps the arg-max output index to this list, so the order
# here must match the order the checkpoint was trained with.
class_names = [
    "Cardboard",
    "Food Organics",
    "Glass",
    "Metal",
    "Miscellaneous Trash",
    "Paper",
    "Plastic",
    "Textile Trash",
    "Vegetation",
]

# Build the network on the target device, then fill in the trained weights.
best_model = ResNet101UNet(num_classes=len(class_names)).to(device)

# The checkpoint sits next to this file, so resolve it relative to the
# script rather than the current working directory.
_weights_file = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "3q7y4e.safetensors"
)
load_model(best_model, _weights_file)

classifier = WasteClassifier(best_model, class_names, device)

# Build the UI and start serving it.
demo = interface(classifier)
demo.launch()