File size: 2,953 Bytes
c3d82b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from pathlib import Path
from loguru import logger
from src.model import LitEfficientNet
from src.utils.aws_s3_services import S3Handler

# Log INFO-and-above records to logs/inference.log, rotating the file at 1 MB.
logger.add(
    "logs/inference.log",
    level="INFO",
    rotation="1 MB",
)


class MNISTClassifier:
    """
    Wraps a ``LitEfficientNet`` checkpoint for single-image digit inference.

    Construction selects CUDA when available (CPU otherwise), loads the
    checkpoint onto that device, switches the model to evaluation mode and
    builds the preprocessing pipeline applied to every input image.
    """

    def __init__(self, checkpoint_path="./checkpoints/best_model.ckpt"):
        """
        Set up the device, model, preprocessing pipeline and label list.

        Args:
            checkpoint_path: Location of the checkpoint file to load.

        Raises:
            FileNotFoundError: If ``checkpoint_path`` does not exist.
        """
        self.checkpoint_path = checkpoint_path

        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)
        logger.info(f"Inference will run on device: {self.device}")

        # Restore the network, then switch it to evaluation mode
        self.model = self.load_model()
        self.model.eval()

        # Preprocessing: resize to 28x28, convert to a tensor, then
        # normalize with mean 0.5 and std 0.5
        preprocessing_steps = [
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
        self.transform = transforms.Compose(preprocessing_steps)

        # Class names are the digit strings "0" through "9"
        self.labels = list(map(str, range(10)))

    def load_model(self):
        """
        Restore the model from ``self.checkpoint_path`` onto ``self.device``.

        Returns:
            The result of ``LitEfficientNet.load_from_checkpoint`` after
            being moved to ``self.device``.

        Raises:
            FileNotFoundError: If the checkpoint path does not exist.
        """
        ckpt = self.checkpoint_path
        if Path(ckpt).exists():
            logger.info(f"Loading model from checkpoint: {ckpt}")
            restored = LitEfficientNet.load_from_checkpoint(ckpt)
            return restored.to(self.device)

        # Missing checkpoint: log it, then fail loudly
        logger.error(f"Checkpoint not found: {ckpt}")
        raise FileNotFoundError(f"Checkpoint not found: {ckpt}")

    @torch.no_grad()
    def predict(self, image):
        """
        Classify one image and report a probability for every label.

        Args:
            image: Input image in PIL format, or ``None``.

        Returns:
            dict | None: Mapping of label string to softmax probability, or
            ``None`` (after logging an error) when ``image`` is ``None``.
        """
        if image is None:
            logger.error("No image provided for prediction.")
            return None

        # Preprocess, add a leading batch dimension, move to self.device
        batch = self.transform(image)
        batch = batch.unsqueeze(0).to(self.device)

        # Forward pass; softmax over the first item of the model output
        logits = self.model(batch)
        scores = torch.softmax(logits[0], dim=0)

        # Pair each probability with its label by position
        result = {}
        for position, score in enumerate(scores):
            result[self.labels[position]] = float(score)
        return result


# Path of the checkpoint file handed to the classifier
checkpoint_path = "./checkpoints/best_model.ckpt"

# Fetch checkpoints from S3 on every import, before the model is loaded
# (arguments presumably: remote folder, local folder - confirm against S3Handler)
s3_handler = S3Handler(bucket_name="deep-bucket-s3")
s3_handler.download_folder("checkpoints_test", "checkpoints")

# Build the classifier from the checkpoint path above
classifier = MNISTClassifier(checkpoint_path=checkpoint_path)

# Gradio components: a grayscale PIL image in, the single top label out
_image_input = gr.Image(height=160, width=160, image_mode="L", type="pil")
_label_output = gr.Label(num_top_classes=1)

# Gradio interface wired to the classifier's predict method
demo = gr.Interface(
    fn=classifier.predict,
    inputs=_image_input,
    outputs=_label_output,
    title="MNIST Classifier",
    description="Upload a handwritten digit image to classify it (0-9).",
)

if __name__ == "__main__":
    demo.launch(share=True)