from multiprocessing import cpu_count
from typing import Optional

import lightning as pl
from loguru import logger
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

# Configure Loguru to save logs to the logs/ directory
logger.add("logs/dataloader.log", rotation="1 MB", level="INFO")


class MNISTDataModule(pl.LightningDataModule):
    def __init__(
        self,
        batch_size: int = 64,
        data_dir: str = "./data",
        num_workers: int = cpu_count(),
        train_subset_fraction: float = 0.25,  # Fraction of training data to use
    ):
        """
        Initializes the MNIST Data Module with configurations for dataloaders.

        Args:
            batch_size (int): Batch size for training, validation, and testing.
            data_dir (str): Directory to download and store the dataset.
            num_workers (int): Number of workers for data loading.
            train_subset_fraction (float): Fraction of training data to use
                (0.0 < fraction <= 1.0).
        """
        super().__init__()
        if not 0.0 < train_subset_fraction <= 1.0:
            raise ValueError("train_subset_fraction must be in (0.0, 1.0].")
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.num_workers = num_workers
        self.train_subset_fraction = train_subset_fraction
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
        )
        logger.info(f"MNIST DataModule initialized with batch size {self.batch_size}")

    def prepare_data(self):
        """Downloads the MNIST dataset if not already downloaded."""
        datasets.MNIST(root=self.data_dir, train=True, download=True)
        datasets.MNIST(root=self.data_dir, train=False, download=True)
        logger.info("MNIST dataset downloaded.")

    def setup(self, stage: Optional[str] = None):
        """
        Sets up the datasets for the given stage.

        Args:
            stage (str, optional): One of "fit", "validate", "test", or "predict".
        """
        logger.info(f"Setting up data for stage: {stage}")
        if stage == "fit" or stage is None:
            full_train_dataset = datasets.MNIST(
                root=self.data_dir, train=True, transform=self.transform
            )
            if self.train_subset_fraction < 1.0:
                # train_test_split rejects train_size=1.0, so only subsample
                # when a strict fraction of the training data is requested.
                train_indices, _ = train_test_split(
                    range(len(full_train_dataset)),
                    train_size=self.train_subset_fraction,
                    random_state=42,
                )
                self.mnist_train = Subset(full_train_dataset, train_indices)
            else:
                self.mnist_train = full_train_dataset
            # Note: the MNIST test split doubles as the validation set here.
            self.mnist_val = datasets.MNIST(
                root=self.data_dir, train=False, transform=self.transform
            )
            logger.info(f"Loaded training subset: {len(self.mnist_train)} samples.")
            logger.info(f"Loaded validation data: {len(self.mnist_val)} samples.")

        if stage == "test" or stage is None:
            self.mnist_test = datasets.MNIST(
                root=self.data_dir, train=False, transform=self.transform
            )
            logger.info(f"Loaded test data: {len(self.mnist_test)} samples.")

    def train_dataloader(self) -> DataLoader:
        """
        Returns the training DataLoader.

        Returns:
            DataLoader: Training data loader.
        """
        logger.info("Creating training DataLoader...")
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self) -> DataLoader:
        """
        Returns the validation DataLoader.

        Returns:
            DataLoader: Validation data loader.
        """
        logger.info("Creating validation DataLoader...")
        return DataLoader(
            self.mnist_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self) -> DataLoader:
        """
        Returns the test DataLoader.

        Returns:
            DataLoader: Test data loader.
        """
        logger.info("Creating test DataLoader...")
        return DataLoader(
            self.mnist_test,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
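
# --- Usage sketch (illustrative; not part of the module itself) ---
# A minimal smoke test of the DataModule when the file is run as a script.
# The batch size and subset fraction below are arbitrary example values.
if __name__ == "__main__":
    dm = MNISTDataModule(batch_size=32, train_subset_fraction=0.1)
    dm.prepare_data()
    dm.setup("fit")
    # Pull one batch to verify shapes: images are (B, 1, 28, 28), labels (B,).
    images, labels = next(iter(dm.train_dataloader()))
    logger.info(f"Batch images: {tuple(images.shape)}, labels: {tuple(labels.shape)}")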