File size: 3,998 Bytes
c3d82b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
from loguru import logger
from src.model import LitEfficientNet
from src.dataloader import MNISTDataModule
from torchmetrics.classification import Accuracy
from pathlib import Path
from src.utils.aws_s3_services import S3Handler

# Send INFO-and-above records to logs/test.log as well; the file is rotated once it reaches 1 MB.
logger.add("logs/test.log", level="INFO", rotation="1 MB")


def infer(checkpoint_path, image):
    """
    Perform inference on a single image using the model checkpoint.

    The checkpoint is loaded on every call, so this suits one-off predictions
    rather than repeated/batch serving.

    Args:
        checkpoint_path (str | Path): Path to the model checkpoint.
        image (torch.Tensor): A single image. Accepted shapes are [H, W]
            (e.g. [28, 28]), [C, H, W] (e.g. [1, 28, 28] for MNIST), or an
            already-batched [1, C, H, W]; missing leading dims are added.

    Returns:
        int: Predicted class (0-9).

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    logger.info(f"Loading model from checkpoint: {checkpoint_path} for inference...")
    if not Path(checkpoint_path).exists():
        logger.error(f"Checkpoint not found: {checkpoint_path}")
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # Detect device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Inference will run on device: {device}")

    # Restore the weights directly onto the target device (map_location) rather
    # than onto the device they were saved from and copying afterwards.
    model = LitEfficientNet.load_from_checkpoint(
        checkpoint_path, map_location=device
    ).to(device)
    model.eval()

    # Assumes the model takes batched [N, C, H, W] input, like the test-loader
    # batches: add a channel dim for a bare [H, W] image, then a batch dim.
    if image.dim() == 2:
        image = image.unsqueeze(0)
    if image.dim() == 3:
        image = image.unsqueeze(0)
    # Ensure the image is on the same device as the model.
    image = image.to(device)

    with torch.no_grad():
        prediction = model(image)
        # .item() needs exactly one prediction, i.e. a single input image.
        predicted_class = torch.argmax(prediction, dim=1).item()

    logger.info(f"Predicted class: {predicted_class}")
    return predicted_class


def test_model(checkpoint_path, *, num_classes=10):
    """
    Test the model using the test dataset and log metrics.

    Args:
        checkpoint_path (str | Path): Path to the model checkpoint.
        num_classes (int): Keyword-only. Number of classes the multiclass
            accuracy metric is configured for. Defaults to 10 (MNIST digits).

    Returns:
        float: Final test accuracy as a fraction in [0, 1].

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    logger.info(f"Loading model from checkpoint: {checkpoint_path} for testing...")
    if not Path(checkpoint_path).exists():
        logger.error(f"Checkpoint not found: {checkpoint_path}")
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # Detect device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Testing will run on device: {device}")

    # Restore the weights directly onto the target device, then switch to eval
    # mode (affects layers such as dropout and batch norm).
    model = LitEfficientNet.load_from_checkpoint(
        checkpoint_path, map_location=device
    ).to(device)
    model.eval()

    # Set up data module and load test data
    data_module = MNISTDataModule()
    data_module.setup(stage="test")
    test_loader = data_module.test_dataloader()

    # Metric state accumulates across update() calls; compute() aggregates it
    # over every batch seen. It must live on the same device as the outputs.
    test_acc = Accuracy(num_classes=num_classes, task="multiclass").to(device)

    logger.info("Evaluating on test dataset...")
    with torch.no_grad():
        for images, labels in test_loader:
            # Move data to the same device as the model and the metric state.
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            test_acc.update(outputs, labels)

    accuracy = test_acc.compute().item()
    logger.info(f"Final Test Accuracy (TorchMetrics): {accuracy:.2%}")
    return accuracy


if __name__ == "__main__":
    checkpoint_path = "./checkpoints/best_model.ckpt"
    try:
        # Download the checkpoint folder from S3. The "checkpoints" argument is
        # presumably the local destination that checkpoint_path points into —
        # NOTE(review): confirm against S3Handler.download_folder.
        s3_handler = S3Handler(bucket_name="deep-bucket-s3")
        s3_handler.download_folder(
            "checkpoints_test",
            "checkpoints",
        )

        # Perform testing
        test_accuracy = test_model(checkpoint_path)
        logger.info(f"Test completed successfully with accuracy: {test_accuracy:.2%}")

        # Example inference on random noise: this only smoke-tests the inference
        # path; the predicted class is meaningless. Replace with a real test image.
        logger.info("Running inference on a single test image...")
        dummy_image = torch.randn(1, 28, 28)
        predicted_class = infer(checkpoint_path, dummy_image)
        logger.info(f"Inference result: Predicted class {predicted_class}")
    except Exception as e:
        # Top-level boundary: record the full traceback, then exit non-zero so a
        # CI job or calling shell sees the failure instead of a silent success.
        logger.exception(f"An error occurred: {e}")
        raise SystemExit(1) from e