from loguru import logger
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import lightning as pl
from typing import Optional
from multiprocessing import cpu_count
from sklearn.model_selection import train_test_split

# Configure Loguru to save logs to the logs/ directory
logger.add("logs/dataloader.log", rotation="1 MB", level="INFO")


class MNISTDataModule(pl.LightningDataModule):
    def __init__(
        self,
        batch_size: int = 64,
        data_dir: str = "./data",
        num_workers: int = cpu_count(),
        train_subset_fraction: float = 0.25,  # Fraction of training data to use
    ):
        """
        Initializes the MNIST Data Module with configurations for dataloaders.

        Args:
            batch_size (int): Batch size for training, validation, and testing.
            data_dir (str): Directory to download and store the dataset.
            num_workers (int): Number of workers for data loading.
            train_subset_fraction (float): Fraction of training data to use (0.0 < fraction <= 1.0).
        """
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.num_workers = num_workers
        self.train_subset_fraction = train_subset_fraction
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
        )
        logger.info(f"MNIST DataModule initialized with batch size {self.batch_size}")

    def prepare_data(self):
        """
        Downloads the MNIST dataset if not already downloaded.
        """
        datasets.MNIST(root=self.data_dir, train=True, download=True)
        datasets.MNIST(root=self.data_dir, train=False, download=True)
        logger.info("MNIST dataset downloaded.")

    def setup(self, stage: Optional[str] = None):
        """
        Set up the dataset for different stages.

        Args:
            stage (str, optional): One of "fit", "validate", "test", or "predict".
        """
        logger.info(f"Setting up data for stage: {stage}")
        if stage == "fit" or stage is None:
            full_train_dataset = datasets.MNIST(
                root=self.data_dir, train=True, transform=self.transform
            )
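            # Deterministically keep a fraction of the training set;
            # train_test_split returns (kept_indices, dropped_indices).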
            train_indices, _ = train_test_split(
                range(len(full_train_dataset)),
                train_size=self.train_subset_fraction,
                random_state=42,
            )
            self.mnist_train = Subset(full_train_dataset, train_indices)

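            # MNIST ships only train/test splits; the test split is reused here
            # as the validation set.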
            self.mnist_val = datasets.MNIST(
                root=self.data_dir, train=False, transform=self.transform
            )
            logger.info(f"Loaded training subset: {len(self.mnist_train)} samples.")
            logger.info(f"Loaded validation data: {len(self.mnist_val)} samples.")
        if stage in (None, "test"):
            self.mnist_test = datasets.MNIST(
                root=self.data_dir, train=False, transform=self.transform
            )
            logger.info(f"Loaded test data: {len(self.mnist_test)} samples.")

    def train_dataloader(self) -> DataLoader:
        """
        Returns the training DataLoader.

        Returns:
            DataLoader: Training data loader.
        """
        logger.info("Creating training DataLoader...")
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self) -> DataLoader:
        """
        Returns the validation DataLoader.

        Returns:
            DataLoader: Validation data loader.
        """
        logger.info("Creating validation DataLoader...")
        return DataLoader(
            self.mnist_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self) -> DataLoader:
        """
        Returns the test DataLoader.

        Returns:
            DataLoader: Test data loader.
        """
        logger.info("Creating test DataLoader...")
        return DataLoader(
            self.mnist_test,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
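

if __name__ == "__main__":
    # Illustrative smoke test (a minimal sketch, not part of the module's API):
    # download MNIST if needed, build the training loader, and log the shape
    # of a single batch. The argument values below are arbitrary examples.
    dm = MNISTDataModule(batch_size=32, train_subset_fraction=0.1)
    dm.prepare_data()
    dm.setup("fit")
    images, labels = next(iter(dm.train_dataloader()))
    logger.info(
        f"Sample batch: images {tuple(images.shape)}, labels {tuple(labels.shape)}"
    )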