demo-ml-v2 / app.py
spuuntries
feat: add project files
8f7598e
raw
history blame
3.27 kB
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
from safetensors.torch import load_model, save_model
from models import *
import os
class WasteClassifier:
    """Pair a classification model with its preprocessing and label decoding.

    Args:
        model: Model called as ``model(batch)``; its output is soft-maxed
            along dim 1, so it is expected to be ``(batch, num_classes)``
            scores.
        class_names: Sequence of labels; position ``i`` names output
            column ``i``.
        device: Torch device the input batch is moved to before the
            forward pass. The model itself is not moved here.
    """

    def __init__(self, model, class_names, device):
        self.model = model
        self.class_names = class_names
        self.device = device
        # Fixed 384x384 resize followed by per-channel normalization.
        # The mean/std values are the widely used ImageNet statistics and
        # assume a 3-channel (RGB) input.
        self.transform = transforms.Compose(
            [
                transforms.Resize((384, 384)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def predict(self, image):
        """Classify a single image.

        Args:
            image: A ``PIL.Image.Image``, or anything accepted by
                ``Image.fromarray`` (e.g. a NumPy array).

        Returns:
            dict with keys:
                ``"predicted_class"``: label with the highest probability,
                ``"confidence"``: that probability as a Python ``float``,
                ``"class_probabilities"``: ``{label: float}`` for every
                entry of ``class_names``.
        """
        self.model.eval()

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        # PNGs with alpha, palette images and grayscale images do not have
        # three channels; Normalize above (and the model) need exactly RGB.
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Add the batch dimension the model expects.
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            outputs = self.model(img_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)

        probs = probabilities[0].cpu().numpy()
        best_index = int(np.argmax(probs))

        return {
            "predicted_class": self.class_names[best_index],
            # Plain float keeps the result consistent with the per-class
            # values below and JSON-serializable.
            "confidence": float(probs[best_index]),
            "class_probabilities": {
                class_name: float(prob)
                for class_name, prob in zip(self.class_names, probs)
            },
        }
def interface(classifier):
    """Build the Gradio UI around ``classifier``.

    Args:
        classifier: Object exposing ``predict(image)`` that returns a dict
            with ``"predicted_class"``, ``"confidence"`` and
            ``"class_probabilities"`` (label -> probability).

    Returns:
        The constructed ``gr.Interface`` (not yet launched).
    """

    def process_image(image):
        """Run the classifier and format its result as a text report."""
        # Gradio passes None when the user submits without an image.
        if image is None:
            return "No image provided. Please upload an image to classify."

        results = classifier.predict(image)
        # Most likely class first.
        ranked = sorted(
            results["class_probabilities"].items(),
            key=lambda item: item[1],
            reverse=True,
        )
        lines = [
            f"Predicted Class: {results['predicted_class']}",
            f"Confidence: {results['confidence']*100:.2f}%",
            "",
            "Class Probabilities:",
        ]
        lines.extend(f"{class_name}: {prob*100:.2f}%" for class_name, prob in ranked)
        # Trailing newline kept so the report text matches the previous format.
        return "\n".join(lines) + "\n"

    # Only offer example images that are actually present (paths are relative
    # to the current working directory); a missing file in the examples list
    # would otherwise be passed to Gradio as-is.
    example_files = ["example1.jpg", "example2.jpg", "example3.jpg"]
    examples = [[path] for path in example_files if os.path.exists(path)] or None

    demo = gr.Interface(
        fn=process_image,
        inputs=[gr.Image(type="pil", label="Upload Image")],
        outputs=[gr.Textbox(label="Classification Results")],
        title="Waste Classification System",
        description=(
            "\n"
            "Upload an image of waste to classify it into different categories.\n"
            "The model will predict the type of waste and show confidence scores for each category.\n"
        ),
        examples=examples,
        analytics_enabled=False,
        theme="default",
    )
    return demo
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Label order must match the model's output columns.
class_names = [
    "Cardboard",
    "Food Organics",
    "Glass",
    "Metal",
    "Miscellaneous Trash",
    "Paper",
    "Plastic",
    "Textile Trash",
    "Vegetation",
]

# ResNet50 is not defined in this file; presumably provided by the
# star-import from `models`.
best_model = ResNet50(num_classes=len(class_names))
best_model = best_model.to(device)

# Resolve the weights file next to this script. Joining ".." directly onto
# __file__ yields "<dir>/app.py/../<name>", which POSIX systems reject with
# ENOTDIR because app.py is not a directory.
weights_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "bjf8fp.safetensors"
)
load_model(best_model, weights_path)

classifier = WasteClassifier(best_model, class_names, device)
demo = interface(classifier)
demo.launch(share=True)