File size: 2,953 Bytes
c3d82b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from pathlib import Path
from loguru import logger
from src.model import LitEfficientNet
from src.utils.aws_s3_services import S3Handler
# Add a rotating file sink (alongside loguru's default stderr sink): INFO and
# above go to logs/inference.log, which is rotated once it reaches 1 MB.
logger.add("logs/inference.log", level="INFO", rotation="1 MB")
class MNISTClassifier:
    """Wrap a ``LitEfficientNet`` checkpoint for single-image digit classification.

    The model is loaded once at construction, switched to eval mode, and
    placed on CUDA when available (CPU otherwise).
    """

    def __init__(self, checkpoint_path="./checkpoints/best_model.ckpt"):
        """Load the checkpoint and build the preprocessing pipeline.

        Args:
            checkpoint_path: Path to the Lightning ``.ckpt`` file to load.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` is not an existing file.
        """
        self.checkpoint_path = checkpoint_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Inference will run on device: {self.device}")

        # Load the model and switch it to inference mode (affects layers such
        # as dropout / batch-norm, if the model has any).
        self.model = self.load_model()
        self.model.eval()

        # Resize to 28x28, convert to a [0, 1] tensor, then rescale to [-1, 1]
        # via (x - 0.5) / 0.5.
        self.transform = transforms.Compose(
            [
                transforms.Resize((28, 28)),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )
        self.labels = [str(i) for i in range(10)]  # MNIST labels are 0-9

    def load_model(self):
        """Load the Lightning checkpoint and move the model to ``self.device``.

        Returns:
            The restored ``LitEfficientNet`` module.

        Raises:
            FileNotFoundError: If ``self.checkpoint_path`` is not an existing file.
        """
        # is_file() rather than exists(): a directory at this path cannot be
        # loaded as a checkpoint either.
        if not Path(self.checkpoint_path).is_file():
            logger.error(f"Checkpoint not found: {self.checkpoint_path}")
            raise FileNotFoundError(f"Checkpoint not found: {self.checkpoint_path}")

        logger.info(f"Loading model from checkpoint: {self.checkpoint_path}")
        # map_location restores tensors directly onto the target device, so a
        # checkpoint saved on GPU also loads on a CPU-only host.
        model = LitEfficientNet.load_from_checkpoint(
            self.checkpoint_path, map_location=self.device
        )
        return model.to(self.device)

    @torch.no_grad()
    def predict(self, image):
        """Classify a single digit image.

        Args:
            image: Input image, expected in PIL format. Images whose ``mode``
                is not grayscale ("L") are converted first, because the
                single-value Normalize (and presumably the model) expects
                one channel.

        Returns:
            dict mapping each label ("0"-"9") to its softmax probability,
            or None when ``image`` is None.
        """
        if image is None:
            logger.error("No image provided for prediction.")
            return None

        # The Gradio input already delivers "L" images, but direct callers may
        # pass RGB/RGBA; ToTensor would then yield a 3/4-channel tensor.
        if getattr(image, "mode", "L") != "L":
            image = image.convert("L")

        # Preprocess and add a batch dimension of 1.
        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        output = self.model(img_tensor)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)

        # A single .tolist() copies all probabilities to host memory at once,
        # instead of one device sync per element via float(tensor).
        return {
            self.labels[idx]: prob
            for idx, prob in enumerate(probabilities.tolist())
        }
# Local path of the checkpoint the classifier loads.
checkpoint_path = "./checkpoints/best_model.ckpt"

# Download the "checkpoints_test" folder from S3 into the checkpoint's
# directory. NOTE: this runs unconditionally at every module import, not only
# when the checkpoint is missing locally. Argument order is presumably
# (remote folder, local folder) — confirm against S3Handler.download_folder.
s3_handler = S3Handler(bucket_name="deep-bucket-s3")
s3_handler.download_folder(
    "checkpoints_test",
    # Derived from checkpoint_path ("checkpoints") so the download target
    # cannot drift from the path the classifier reads.
    str(Path(checkpoint_path).parent),
)

# Instantiate the classifier from the downloaded checkpoint.
classifier = MNISTClassifier(checkpoint_path=checkpoint_path)
# Gradio front end: the upload component (displayed at 160x160) hands
# classifier.predict a grayscale ("L") PIL image; the Label output shows the
# single highest-probability digit.
demo = gr.Interface(
    classifier.predict,
    gr.Image(type="pil", image_mode="L", width=160, height=160),
    gr.Label(num_top_classes=1),
    title="MNIST Classifier",
    description="Upload a handwritten digit image to classify it (0-9).",
)

if __name__ == "__main__":
    # share=True also requests a temporary public share link from Gradio.
    demo.launch(share=True)
|