import os

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from safetensors.torch import load_model

from models import ResNet101UNet


class WasteClassifier:
    """Wraps a joint classification + segmentation model for single-image inference."""

    def __init__(self, model, class_names, device):
        self.model = model
        self.class_names = class_names
        self.device = device
        # Standard ImageNet preprocessing at the model's 384x384 input resolution.
        self.transform = transforms.Compose(
            [
                transforms.Resize((384, 384)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def predict(self, image):
        self.model.eval()
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        original_size = image.size
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # The model returns both classification logits and a segmentation mask.
            outputs, seg_mask = self.model(img_tensor)
            probabilities = F.softmax(outputs, dim=1)

        probs = probabilities[0].cpu().numpy()
        pred_class = self.class_names[np.argmax(probs)]
        confidence = np.max(probs)

        # Process segmentation mask: take the first image, first channel.
        seg_mask = seg_mask[0, 0].cpu().numpy().astype(np.float32)
        # seg_mask = (seg_mask >= 0.2).astype(np.float32)  # Threshold at 0.2

        # Resize mask back to original image size.
        seg_mask = Image.fromarray(seg_mask)
        seg_mask = seg_mask.resize(original_size, Image.NEAREST)
        seg_mask = np.array(seg_mask)

        results = {
            "predicted_class": pred_class,
            "confidence": confidence,
            "class_probabilities": {
                class_name: float(prob)
                for class_name, prob in zip(self.class_names, probs)
            },
            "segmentation_mask": seg_mask,
        }
        return results


def interface(classifier):
    def process_image(image):
        results = classifier.predict(image)

        image_np = np.array(image) if isinstance(image, Image.Image) else image

        # Black out pixels the mask marks as background (values below 0.2).
        mask = results["segmentation_mask"]
        overlay = image_np.copy()
        overlay[mask < 0.2] = 0

        # Build a human-readable summary, classes sorted by descending probability.
        output_str = f"Predicted Class: {results['predicted_class']}\n"
        output_str += f"Confidence: {results['confidence'] * 100:.2f}%\n\n"
        output_str += "Class Probabilities:\n"
        sorted_probs = sorted(
            results["class_probabilities"].items(), key=lambda x: x[1], reverse=True
        )
        for class_name, prob in sorted_probs:
            output_str += f"{class_name}: {prob * 100:.2f}%\n"

        mask_viz = (mask * 255).astype(np.uint8)
        return [output_str, overlay, mask_viz]

    demo = gr.Interface(
        fn=process_image,
        inputs=[gr.Image(type="pil", label="Upload Image")],
        outputs=[
            gr.Textbox(label="Classification Results"),
            gr.Image(label="Segmented Object"),
            gr.Image(label="Segmentation Mask"),
        ],
        title="Waste Classification System",
        description="""
        Upload an image of waste to classify it into different categories.
        The model will predict the type of waste, show confidence scores for each
        category, and display the segmented object along with its mask.
        """,
        examples=(
            [["example1.jpg"], ["example2.jpg"], ["example3.jpg"]]
            if os.path.exists("example1.jpg")
            else None
        ),
        analytics_enabled=False,
        theme="default",
    )
    return demo


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class_names = [
    "Cardboard",
    "Food Organics",
    "Glass",
    "Metal",
    "Miscellaneous Trash",
    "Paper",
    "Plastic",
    "Textile Trash",
    "Vegetation",
]

# Build the model and load the trained weights stored next to this script.
best_model = ResNet101UNet(num_classes=len(class_names))
best_model = best_model.to(device)
load_model(
    best_model,
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "3q7y4e.safetensors"),
)

classifier = WasteClassifier(best_model, class_names, device)
demo = interface(classifier)

if __name__ == "__main__":
    demo.launch()
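
# ---------------------------------------------------------------------------
# Optional sanity check (a sketch, not part of the app itself): assuming an
# image such as "example1.jpg" sits next to this script, the classifier can be
# exercised without the Gradio UI. Kept commented out so that importing or
# running the app is unaffected.
#
#     sample = Image.open("example1.jpg").convert("RGB")
#     results = classifier.predict(sample)
#     print(results["predicted_class"], f"{results['confidence'] * 100:.2f}%")
# ---------------------------------------------------------------------------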