# demo-ml-v2 / app.py
# spuuntries
# feat: add new model
# c9e9eb6
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import torch.nn.functional as F
from safetensors.torch import load_model, save_model
from models import *
import os
class WasteClassifier:
    """Single-image inference wrapper for a classification + segmentation model.

    ``model`` is called on a normalized batch containing one 384x384 image and
    must return a ``(class_logits, segmentation_mask)`` pair.
    """

    def __init__(self, model, class_names, device):
        """Store the model and build the preprocessing pipeline.

        Args:
            model: Module returning ``(class_logits, segmentation_mask)``.
            class_names: Labels; probability ``i`` is reported under
                ``class_names[i]``.
            device: Device the input tensor is moved to before inference.
        """
        self.model = model
        self.class_names = class_names
        self.device = device
        # NOTE(review): these are the standard ImageNet mean/std values;
        # presumably they match the preprocessing used in training — confirm.
        self.transform = transforms.Compose(
            [
                transforms.Resize((384, 384)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def predict(self, image):
        """Classify ``image`` and return its segmentation mask.

        Args:
            image: A ``PIL.Image.Image`` or an array accepted by
                ``PIL.Image.fromarray``. Non-RGB images are converted to RGB.

        Returns:
            dict with keys:
              - ``"predicted_class"``: label with the highest probability.
              - ``"confidence"``: that probability, as a Python ``float``.
              - ``"class_probabilities"``: mapping label -> ``float``.
              - ``"segmentation_mask"``: 2-D float32 array resized back to the
                input image's (height, width).

        Raises:
            ValueError: if ``image`` is None.
        """
        if image is None:
            raise ValueError("predict() requires an image, got None")
        self.model.eval()
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        if image.mode != "RGB":
            # Normalize() has three channel statistics, so RGBA, grayscale
            # ("L") or palette ("P") inputs would fail inside the transform.
            image = image.convert("RGB")
        original_size = image.size  # (width, height), as PIL reports it

        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits, seg_mask = self.model(img_tensor)
            probs = torch.nn.functional.softmax(logits, dim=1)[0].cpu().numpy()
            # Take the first batch item and first channel of the mask output.
            # NOTE(review): assumes seg_mask is laid out (batch, channel, H, W).
            mask = seg_mask[0, 0].cpu().numpy().astype(np.float32)

        best = int(np.argmax(probs))

        # Resize the mask back to the input resolution. A 2-D float32 array
        # becomes a mode "F" PIL image, so values are not rounded to uint8.
        mask = np.array(Image.fromarray(mask).resize(original_size, Image.NEAREST))

        return {
            "predicted_class": self.class_names[best],
            "confidence": float(probs[best]),
            "class_probabilities": {
                name: float(prob) for name, prob in zip(self.class_names, probs)
            },
            "segmentation_mask": mask,
        }
def interface(classifier, mask_threshold=0.2):
    """Build the Gradio UI around ``classifier``.

    Args:
        classifier: Object with a ``predict(image)`` method returning a dict
            with ``predicted_class``, ``confidence``, ``class_probabilities``
            and ``segmentation_mask`` keys.
        mask_threshold: Pixels whose mask value is below this are blacked out
            in the "Segmented Object" output. Defaults to 0.2, the value that
            was previously hard-coded.

    Returns:
        The configured (not yet launched) ``gr.Interface``.
    """

    def process_image(image):
        """Run inference and produce the text report, overlay and mask image."""
        if image is None:
            # Show a readable message in the UI instead of a stack trace.
            raise gr.Error("Please upload an image first.")
        results = classifier.predict(image)
        image_np = np.array(image) if isinstance(image, Image.Image) else image
        mask = results["segmentation_mask"]

        # Black out every pixel the mask considers background.
        overlay = image_np.copy()
        overlay[mask < mask_threshold] = 0

        lines = [
            f"Predicted Class: {results['predicted_class']}",
            f"Confidence: {results['confidence']*100:.2f}%",
            "",
            "Class Probabilities:",
        ]
        ranked = sorted(
            results["class_probabilities"].items(), key=lambda kv: kv[1], reverse=True
        )
        lines.extend(f"{name}: {prob*100:.2f}%" for name, prob in ranked)
        output_str = "\n".join(lines) + "\n"

        # Clip before scaling: values outside [0, 1] would otherwise wrap
        # around when cast to uint8.
        mask_viz = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)
        return [output_str, overlay, mask_viz]

    # Only offer example files that actually exist; listing a missing file
    # breaks the examples gallery.
    example_files = [
        name
        for name in ("example1.jpg", "example2.jpg", "example3.jpg")
        if os.path.exists(name)
    ]
    examples = [[name] for name in example_files] or None

    demo = gr.Interface(
        fn=process_image,
        inputs=[gr.Image(type="pil", label="Upload Image")],
        outputs=[
            gr.Textbox(label="Classification Results"),
            gr.Image(label="Segmented Object"),
            gr.Image(label="Segmentation Mask"),
        ],
        title="Waste Classification System",
        description="""
    Upload an image of waste to classify it into different categories.
    The model will predict the type of waste, show confidence scores for each category,
    and display the segmented object along with its mask.
    """,
        examples=examples,
        analytics_enabled=False,
        theme="default",
    )
    return demo
# --- Application setup -------------------------------------------------------
# Runs at import time so `device`, `class_names`, `best_model`, `classifier`
# and `demo` are available as module-level names.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Probabilities are mapped to these names by index, and the model head is
# sized from this list. NOTE(review): the order presumably matches the label
# order used during training — confirm.
class_names = [
    "Cardboard",
    "Food Organics",
    "Glass",
    "Metal",
    "Miscellaneous Trash",
    "Paper",
    "Plastic",
    "Textile Trash",
    "Vegetation",
]

best_model = ResNet101UNet(num_classes=len(class_names))
best_model = best_model.to(device)
# Resolve the weights file relative to this script, not the working directory.
load_model(
    best_model,
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "3q7y4e.safetensors"),
)

classifier = WasteClassifier(best_model, class_names, device)
demo = interface(classifier)

# Start the (blocking) server only when run as a script, so importing this
# module does not launch it.
if __name__ == "__main__":
    demo.launch()