Spaces:
Paused
Paused
import pytorch_lightning as pl | |
import torchvision | |
from torch.utils.data import DataLoader, Dataset | |
from torchvision import transforms | |
class CIFAR10DataDictWrapper(Dataset):
    """Adapt an ``(image, label)`` dataset so each item is a dict.

    Every item of the wrapped dataset is unpacked into two values and
    re-emitted as ``{"jpg": image, "cls": label}``. The length is forwarded
    unchanged from the wrapped dataset.

    Args:
        dset: Indexable, sized dataset whose items unpack into two values.
    """

    def __init__(self, dset):
        super().__init__()
        # Wrapped dataset; indexing it must yield exactly two values.
        self.dset = dset

    def __len__(self):
        return len(self.dset)

    def __getitem__(self, i):
        image, label = self.dset[i]
        return dict(jpg=image, cls=label)
class CIFAR10Loader(pl.LightningDataModule):
    """Lightning data module serving CIFAR-10 samples as ``{"jpg", "cls"}`` dicts.

    Both the train and test splits are downloaded (if missing) and wrapped
    eagerly in ``__init__``, so ``train_dataset`` and ``test_dataset`` exist
    immediately after construction. Images are converted with ``ToTensor``
    and then rescaled from ``[0, 1]`` to ``[-1, 1]``.

    Args:
        batch_size: Batch size used by every dataloader.
        num_workers: Number of worker processes used by every dataloader.
        shuffle: Whether every dataloader (train, val and test) shuffles.
        root: Directory where the CIFAR-10 files are stored / downloaded.
    """

    def __init__(self, batch_size, num_workers=0, shuffle=True, root=".data/"):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.root = root
        # Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5 == 2x - 1,
        # the same [0, 1] -> [-1, 1] rescale as a `x * 2.0 - 1.0` lambda.
        # Unlike a lambda it can be pickled, which DataLoader workers require
        # when the multiprocessing start method is "spawn" (e.g. Windows/macOS).
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
            ]
        )
        self.train_dataset = self._load_split(train=True, transform=transform)
        self.test_dataset = self._load_split(train=False, transform=transform)

    def _load_split(self, train, transform):
        """Download (if needed) one CIFAR-10 split and wrap it as dict samples."""
        return CIFAR10DataDictWrapper(
            torchvision.datasets.CIFAR10(
                root=self.root, train=train, download=True, transform=transform
            )
        )

    def _make_loader(self, dataset):
        """Build a DataLoader over *dataset* using this module's settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
        )

    def prepare_data(self):
        # Intentionally empty: the download already happens in __init__.
        # NOTE(review): Lightning recommends downloading here (runs on one
        # process only); if several processes each construct this module,
        # their downloads into `root` may race — confirm how it is built.
        pass

    def train_dataloader(self):
        return self._make_loader(self.train_dataset)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset)

    def val_dataloader(self):
        # Validation reuses the test split.
        # NOTE(review): `shuffle` applies here as well; recent Lightning
        # versions warn about shuffled val/test loaders — confirm intent.
        return self._make_loader(self.test_dataset)