|
from loguru import logger |
|
import torch |
|
from torch.utils.data import DataLoader, Subset |
|
from torchvision import datasets, transforms |
|
import lightning as pl |
|
from typing import Optional |
|
from multiprocessing import cpu_count |
|
from sklearn.model_selection import train_test_split |
|
|
|
|
|
# Persist this module's logs to a file, rotating to a new file once it reaches
# 1 MB; records at INFO level and above are kept.
logger.add("logs/dataloader.log", rotation="1 MB", level="INFO")
|
|
|
|
|
class MNISTDataModule(pl.LightningDataModule):
    """Lightning DataModule for MNIST with an optional training subset.

    Downloads MNIST into ``data_dir``, normalizes images to roughly [-1, 1]
    via ``Normalize((0.5,), (0.5,))``, and serves train/val/test DataLoaders.
    A reproducible fraction of the training split can be selected to speed up
    experiments.

    NOTE(review): validation and test both load the MNIST test split
    (``train=False``), so validation metrics are computed on the test data.
    This is common in toy examples but should be confirmed as intentional.
    """

    def __init__(
        self,
        batch_size: int = 64,
        data_dir: str = "./data",
        num_workers: int = int(cpu_count()),
        train_subset_fraction: float = 0.25,
    ):
        """
        Initializes the MNIST Data Module with configurations for dataloaders.

        Args:
            batch_size (int): Batch size for training, validation, and testing.
            data_dir (str): Directory to download and store the dataset.
            num_workers (int): Number of workers for data loading.
            train_subset_fraction (float): Fraction of training data to use
                (0.0 < fraction <= 1.0).

        Raises:
            ValueError: If ``train_subset_fraction`` is outside (0.0, 1.0].
        """
        super().__init__()
        # Fail fast here rather than letting train_test_split raise a less
        # obvious error later inside setup().
        if not 0.0 < train_subset_fraction <= 1.0:
            raise ValueError(
                "train_subset_fraction must be in (0.0, 1.0], got "
                f"{train_subset_fraction}"
            )
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.num_workers = num_workers
        self.train_subset_fraction = train_subset_fraction
        # ToTensor scales pixels to [0, 1]; Normalize then maps them to [-1, 1].
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
        )
        logger.info(f"MNIST DataModule initialized with batch size {self.batch_size}")

    def prepare_data(self):
        """
        Downloads the MNIST dataset if not already downloaded.

        Called once per node by Lightning; no state is assigned here.
        """
        datasets.MNIST(root=self.data_dir, train=True, download=True)
        datasets.MNIST(root=self.data_dir, train=False, download=True)
        logger.info("MNIST dataset downloaded.")

    def setup(self, stage: Optional[str] = None):
        """
        Set up the dataset for different stages.

        Args:
            stage (str, optional): One of "fit", "validate", "test", or "predict".
        """
        logger.info(f"Setting up data for stage: {stage}")
        if stage == "fit" or stage is None:
            full_train_dataset = datasets.MNIST(
                root=self.data_dir, train=True, transform=self.transform
            )
            if self.train_subset_fraction >= 1.0:
                # sklearn's train_test_split rejects train_size=1.0 (the test
                # side would be empty), so use the full training set directly.
                self.mnist_train = full_train_dataset
            else:
                train_indices, _ = train_test_split(
                    range(len(full_train_dataset)),
                    train_size=self.train_subset_fraction,
                    # Fixed seed so the same subset is drawn on every run.
                    random_state=42,
                )
                self.mnist_train = Subset(full_train_dataset, train_indices)

            # NOTE(review): this is the MNIST *test* split used as validation
            # data — see class docstring.
            self.mnist_val = datasets.MNIST(
                root=self.data_dir, train=False, transform=self.transform
            )
            logger.info(f"Loaded training subset: {len(self.mnist_train)} samples.")
            logger.info(f"Loaded validation data: {len(self.mnist_val)} samples.")
        if stage == "test" or stage is None:
            self.mnist_test = datasets.MNIST(
                root=self.data_dir, train=False, transform=self.transform
            )
            logger.info(f"Loaded test data: {len(self.mnist_test)} samples.")

    def train_dataloader(self) -> DataLoader:
        """
        Returns the training DataLoader.

        Returns:
            DataLoader: Training data loader (shuffled each epoch).
        """
        logger.info("Creating training DataLoader...")
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self) -> DataLoader:
        """
        Returns the validation DataLoader.

        Returns:
            DataLoader: Validation data loader (not shuffled).
        """
        logger.info("Creating validation DataLoader...")
        return DataLoader(
            self.mnist_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self) -> DataLoader:
        """
        Returns the test DataLoader.

        Returns:
            DataLoader: Test data loader (not shuffled).
        """
        logger.info("Creating test DataLoader...")
        return DataLoader(
            self.mnist_test,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
|
|