Spaces:
Runtime error
Runtime error
import torch | |
from loguru import logger | |
from src.model import LitEfficientNet | |
from src.dataloader import MNISTDataModule | |
from torchmetrics.classification import Accuracy | |
from pathlib import Path | |
from src.utils.aws_s3_services import S3Handler | |
# Send INFO-and-above log records to a file sink at logs/test.log,
# rotating the file once it grows to 1 MB.
logger.add("logs/test.log", level="INFO", rotation="1 MB")
def infer(checkpoint_path, image):
    """
    Perform inference on a single image using the model checkpoint.

    Args:
        checkpoint_path (str): Path to the model checkpoint.
        image (torch.Tensor): Image tensor to predict. Accepts an unbatched
            ``[H, W]`` or ``[C, H, W]`` tensor (``[1, 28, 28]`` for MNIST),
            or an already-batched ``[1, C, H, W]`` tensor.

    Returns:
        int: Predicted class index (0-9 for MNIST), i.e. the argmax of the
        model output along dim 1.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    logger.info(f"Loading model from checkpoint: {checkpoint_path} for inference...")
    if not Path(checkpoint_path).exists():
        logger.error(f"Checkpoint not found: {checkpoint_path}")
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # Detect device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Inference will run on device: {device}")

    # Load the model. map_location remaps stored tensors onto the chosen
    # device, so a checkpoint saved on a GPU can still be loaded on a
    # CPU-only host (without it, deserialization of CUDA tensors fails).
    model = LitEfficientNet.load_from_checkpoint(
        checkpoint_path, map_location=device
    ).to(device)
    model.eval()

    # Normalize the input to a 4-D batch of one.
    if image.dim() == 2:
        image = image.unsqueeze(0)  # [H, W] -> [1, H, W]: add channel dimension
    if image.dim() == 3:
        image = image.unsqueeze(0)  # [C, H, W] -> [1, C, H, W]: add batch dimension
    image = image.to(device)  # Ensure the image is on the same device as the model

    # Perform inference
    with torch.no_grad():
        prediction = model(image)
        # .item() requires a single element, i.e. a batch of exactly one image.
        predicted_class = torch.argmax(prediction, dim=1).item()

    logger.info(f"Predicted class: {predicted_class}")
    return predicted_class
def test_model(checkpoint_path):
    """
    Evaluate the model checkpoint on the MNIST test split and log accuracy.

    Args:
        checkpoint_path (str): Path to the model checkpoint.

    Returns:
        float: Final multiclass test accuracy (10 classes) as a fraction.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    logger.info(f"Loading model from checkpoint: {checkpoint_path} for testing...")
    if not Path(checkpoint_path).exists():
        logger.error(f"Checkpoint not found: {checkpoint_path}")
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # Detect device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Testing will run on device: {device}")

    # Load the model. map_location remaps stored tensors onto the chosen
    # device so a GPU-saved checkpoint can be evaluated on a CPU-only host.
    model = LitEfficientNet.load_from_checkpoint(
        checkpoint_path, map_location=device
    ).to(device)
    model.eval()

    # Set up data module and load test data
    data_module = MNISTDataModule()
    data_module.setup(stage="test")
    test_loader = data_module.test_dataloader()

    # Accumulates correct/total counts across batches; computed once at the end.
    test_acc = Accuracy(num_classes=10, task="multiclass").to(device)

    # Evaluate model on test data
    logger.info("Evaluating on test dataset...")
    with torch.no_grad():
        for images, labels in test_loader:
            # Move data to the same device as the model and metric.
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            test_acc.update(outputs, labels)

    accuracy = test_acc.compute().item()
    logger.info(f"Final Test Accuracy (TorchMetrics): {accuracy:.2%}")
    return accuracy
if __name__ == "__main__":
    checkpoint_path = "./checkpoints/best_model.ckpt"
    try:
        # Download the checkpoint folder from S3. This runs inside the try so
        # a failed download is logged like any other error.
        s3_handler = S3Handler(bucket_name="deep-bucket-s3")
        s3_handler.download_folder(
            "checkpoints_test",
            "checkpoints",
        )

        # Perform testing
        test_accuracy = test_model(checkpoint_path)
        logger.info(f"Test completed successfully with accuracy: {test_accuracy:.2%}")

        # Example inference
        logger.info("Running inference on a single test image...")
        dummy_image = torch.randn(1, 28, 28)  # Replace with actual test image
        predicted_class = infer(checkpoint_path, dummy_image)
        logger.info(f"Inference result: Predicted class {predicted_class}")
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        # Exiting non-zero lets CI and other callers detect the failure
        # instead of seeing a silent success.
        logger.exception(f"An error occurred: {e}")
        raise SystemExit(1) from e