from multiprocessing import cpu_count
from typing import Optional

import lightning as pl
import torch
from loguru import logger
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

# Configure Loguru to write logs to the logs/ directory.
logger.add("logs/dataloader.log", rotation="1 MB", level="INFO")

class MNISTDataModule(pl.LightningDataModule):
    def __init__(
        self,
        batch_size: int = 64,
        data_dir: str = "./data",
        num_workers: int = cpu_count(),  # Defaults to all cores; tune on shared hosts.
        train_subset_fraction: float = 0.25,  # Fraction of training data to use.
    ):
        """
        Initializes the MNIST DataModule with dataloader configuration.

        Args:
            batch_size (int): Batch size for training, validation, and testing.
            data_dir (str): Directory in which to download and store the dataset.
            num_workers (int): Number of worker processes for data loading.
            train_subset_fraction (float): Fraction of the training data to use
                (0.0 < fraction <= 1.0).
        """
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.num_workers = num_workers
        self.train_subset_fraction = train_subset_fraction
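        # Scale pixels to [-1, 1]; (0.5, 0.5) is a generic choice. MNIST's
        # empirical mean/std are roughly (0.1307, 0.3081) if exact
        # standardization is preferred.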
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
        )
        logger.info(f"MNIST DataModule initialized with batch size {self.batch_size}")

    def prepare_data(self):
        """
        Downloads the MNIST dataset if it is not already present.
        """
        datasets.MNIST(root=self.data_dir, train=True, download=True)
        datasets.MNIST(root=self.data_dir, train=False, download=True)
        logger.info("MNIST dataset downloaded.")

    def setup(self, stage: Optional[str] = None):
        """
        Sets up the datasets for the given stage.

        Args:
            stage (str, optional): One of "fit", "validate", "test", or "predict".
        """
        logger.info(f"Setting up data for stage: {stage}")
        if stage == "fit" or stage is None:
            full_train_dataset = datasets.MNIST(
                root=self.data_dir, train=True, transform=self.transform
            )
            if self.train_subset_fraction < 1.0:
                # train_test_split rejects an empty held-out side, so only
                # subsample when a strict fraction is requested; a fraction of
                # 1.0 uses the full training set, matching the docstring.
                train_indices, _ = train_test_split(
                    range(len(full_train_dataset)),
                    train_size=self.train_subset_fraction,
                    random_state=42,
                )
                self.mnist_train = Subset(full_train_dataset, train_indices)
            else:
                self.mnist_train = full_train_dataset
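            # Note: the official MNIST test split doubles as the validation
            # set here, so test metrics are not independent of model selection.
            # The unused remainder of the training split is an alternative.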
            self.mnist_val = datasets.MNIST(
                root=self.data_dir, train=False, transform=self.transform
            )
            logger.info(f"Loaded training subset: {len(self.mnist_train)} samples.")
            logger.info(f"Loaded validation data: {len(self.mnist_val)} samples.")
        if stage == "test" or stage is None:
            self.mnist_test = datasets.MNIST(
                root=self.data_dir, train=False, transform=self.transform
            )
            logger.info(f"Loaded test data: {len(self.mnist_test)} samples.")

    def train_dataloader(self) -> DataLoader:
        """
        Returns the training DataLoader.

        Returns:
            DataLoader: Training data loader.
        """
        logger.info("Creating training DataLoader...")
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self) -> DataLoader:
        """
        Returns the validation DataLoader.

        Returns:
            DataLoader: Validation data loader.
        """
        logger.info("Creating validation DataLoader...")
        return DataLoader(
            self.mnist_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self) -> DataLoader:
        """
        Returns the test DataLoader.

        Returns:
            DataLoader: Test data loader.
        """
        logger.info("Creating test DataLoader...")
        return DataLoader(
            self.mnist_test,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
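

if __name__ == "__main__":
    # Usage sketch: construct the DataModule and drive its hooks manually to
    # sanity-check one training batch. Under lightning.Trainer,
    # trainer.fit(model, datamodule=dm) calls these hooks automatically.
    dm = MNISTDataModule(batch_size=32, train_subset_fraction=0.1)
    dm.prepare_data()
    dm.setup("fit")
    images, labels = next(iter(dm.train_dataloader()))
    logger.info(f"Batch shape: {tuple(images.shape)}, labels: {tuple(labels.shape)}")
    assert images.shape == torch.Size([32, 1, 28, 28])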