EnlightenGAN / data /custom_dataset_data_loader.py
HenryGong's picture
Upload 84 files
aba0e05 verified
raw
history blame contribute delete
No virus
1.68 kB
import torch.utils.data
from data.base_data_loader import BaseDataLoader
def CreateDataset(opt):
dataset = None
if opt.dataset_mode == 'aligned':
from data.aligned_dataset import AlignedDataset
dataset = AlignedDataset()
elif opt.dataset_mode == 'unaligned':
from data.unaligned_dataset import UnalignedDataset
dataset = UnalignedDataset()
elif opt.dataset_mode == 'unaligned_random_crop':
from data.unaligned_random_crop import UnalignedDataset
dataset = UnalignedDataset()
elif opt.dataset_mode == 'pair':
from data.pair_dataset import PairDataset
dataset = PairDataset()
elif opt.dataset_mode == 'syn':
from data.syn_dataset import PairDataset
dataset = PairDataset()
elif opt.dataset_mode == 'single':
from data.single_dataset import SingleDataset
dataset = SingleDataset()
else:
raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
print("dataset [%s] was created" % (dataset.name()))
dataset.initialize(opt)
return dataset
class CustomDatasetDataLoader(BaseDataLoader):
def name(self):
return 'CustomDatasetDataLoader'
def initialize(self, opt):
BaseDataLoader.initialize(self, opt)
self.dataset = CreateDataset(opt)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batchSize,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads))
def load_data(self):
return self.dataloader
def __len__(self):
return min(len(self.dataset), self.opt.max_dataset_size)