|
import torch |
|
from loguru import logger |
|
from src.model import LitEfficientNet |
|
from src.dataloader import MNISTDataModule |
|
from torchmetrics.classification import Accuracy |
|
from pathlib import Path |
|
from src.utils.aws_s3_services import S3Handler |
|
|
|
|
|
# Mirror INFO-and-above log records into a file sink that rotates at 1 MB.
logger.add(
    "logs/test.log",
    level="INFO",
    rotation="1 MB",
)
|
|
|
|
|
def infer(checkpoint_path, image):
    """
    Perform inference on a single image using the model checkpoint.

    The checkpoint is (re)loaded on every call, so this is meant for one-off
    predictions rather than repeated high-throughput inference.

    Args:
        checkpoint_path (str | Path): Path to the model checkpoint.
        image (torch.Tensor): A single image to predict. Accepted shapes are
            [H, W] (e.g. [28, 28]), [C, H, W] (e.g. [1, 28, 28] for MNIST),
            or an already-batched [1, C, H, W].

    Returns:
        int: Predicted class index (0-9 for MNIST).

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    logger.info(f"Loading model from checkpoint: {checkpoint_path} for inference...")
    if not Path(checkpoint_path).exists():
        logger.error(f"Checkpoint not found: {checkpoint_path}")
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Inference will run on device: {device}")

    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved on GPU can still be loaded when only CPU is available.
    model = LitEfficientNet.load_from_checkpoint(
        checkpoint_path, map_location=device
    ).to(device)
    model.eval()

    # Add the missing leading dims so the model always receives a 4-D
    # [N, C, H, W] batch: [H, W] -> [1, H, W] -> [1, 1, H, W].
    if image.dim() == 2:
        image = image.unsqueeze(0)
    if image.dim() == 3:
        image = image.unsqueeze(0)

    with torch.no_grad():
        image = image.to(device)
        prediction = model(image)
        # .item() requires a single prediction, i.e. a batch of one image.
        predicted_class = torch.argmax(prediction, dim=1).item()

    logger.info(f"Predicted class: {predicted_class}")
    return predicted_class
|
|
|
|
|
def test_model(checkpoint_path):
    """
    Test the model using the test dataset and log metrics.

    The model is loaded onto CUDA when available, otherwise CPU, and evaluated
    over every batch of ``MNISTDataModule``'s test dataloader with a 10-class
    TorchMetrics ``Accuracy``.

    Args:
        checkpoint_path (str | Path): Path to the model checkpoint.

    Returns:
        float: Final test accuracy in [0, 1].

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    logger.info(f"Loading model from checkpoint: {checkpoint_path} for testing...")
    if not Path(checkpoint_path).exists():
        logger.error(f"Checkpoint not found: {checkpoint_path}")
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Testing will run on device: {device}")

    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved on GPU can still be loaded when only CPU is available.
    model = LitEfficientNet.load_from_checkpoint(
        checkpoint_path, map_location=device
    ).to(device)
    model.eval()

    data_module = MNISTDataModule()
    data_module.setup(stage="test")
    test_loader = data_module.test_dataloader()

    # Accumulates correct/total counts across batches; compute() aggregates.
    test_acc = Accuracy(num_classes=10, task="multiclass").to(device)

    logger.info("Evaluating on test dataset...")
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            test_acc.update(outputs, labels)

    accuracy = test_acc.compute().item()
    logger.info(f"Final Test Accuracy (TorchMetrics): {accuracy:.2%}")
    return accuracy
|
|
|
|
|
if __name__ == "__main__":
    # Fetch checkpoints from S3 before evaluating. Presumably the arguments are
    # (remote folder, local folder) -- confirm against S3Handler.download_folder.
    s3_handler = S3Handler(bucket_name="deep-bucket-s3")
    s3_handler.download_folder(
        "checkpoints_test",
        "checkpoints",
    )
    checkpoint_path = "./checkpoints/best_model.ckpt"

    try:
        test_accuracy = test_model(checkpoint_path)
        logger.info(f"Test completed successfully with accuracy: {test_accuracy:.2%}")

        logger.info("Running inference on a single test image...")
        # Random tensor with a single-MNIST-image shape: this only smoke-tests
        # the inference path; the predicted class itself is not meaningful.
        dummy_image = torch.randn(1, 28, 28)
        predicted_class = infer(checkpoint_path, dummy_image)
        logger.info(f"Inference result: Predicted class {predicted_class}")
    except Exception as e:
        # Keep the full traceback in the log and exit non-zero so that callers
        # (e.g. CI) see the failure instead of a successful exit status.
        logger.exception(f"An error occurred: {e}")
        raise SystemExit(1) from e
|
|