# SyncDreamer/renderer/dummy_dataset.py -- liuyuan-pal, commit 8bb8404 ("init"), 1.05 kB
import pytorch_lightning as pl
from torch.utils.data import Dataset
import webdataset as wds
from torch.utils.data.distributed import DistributedSampler
class DummyDataset(pl.LightningDataModule):
    """Lightning data module that hands out ``DummyData`` placeholder datasets.

    Only the ``'fit'`` stage is supported by :meth:`setup`; any other stage
    raises ``NotImplementedError``.
    """

    # Stages for which ``setup`` knows how to build datasets.
    _SUPPORTED_STAGES = ('fit',)

    def __init__(self, seed):
        """Initialise the data module.

        Args:
            seed: Accepted for interface compatibility; not used here.
        """
        super().__init__()

    def setup(self, stage):
        """Create the train and validation datasets for the ``'fit'`` stage.

        Args:
            stage: Name of the stage being set up.

        Raises:
            NotImplementedError: If ``stage`` is anything other than ``'fit'``.
        """
        if stage not in self._SUPPORTED_STAGES:
            raise NotImplementedError
        self.train_dataset = DummyData(True)
        self.val_dataset = DummyData(False)

    @staticmethod
    def _single_item_loader(dataset):
        """Wrap ``dataset`` in a WebLoader: batch size 1, no workers, no shuffle."""
        return wds.WebLoader(dataset, batch_size=1, num_workers=0, shuffle=False)

    def train_dataloader(self):
        """Return the loader over the training dataset built by ``setup``."""
        return self._single_item_loader(self.train_dataset)

    def val_dataloader(self):
        """Return the loader over the validation dataset built by ``setup``."""
        return self._single_item_loader(self.val_dataset)

    def test_dataloader(self):
        """Return a WebLoader (default options) over a fresh non-train ``DummyData``."""
        test_dataset = DummyData(False)
        return wds.WebLoader(test_dataset)
class DummyData(Dataset):
    """Placeholder dataset whose every item is an empty dict.

    The reported length is 99999999 when ``is_train`` is truthy and 1
    otherwise.
    """

    # Nominal lengths reported by ``__len__`` for each mode.
    _TRAIN_LENGTH = 99999999
    _EVAL_LENGTH = 1

    def __init__(self, is_train):
        """Store the flag that selects which length ``__len__`` reports.

        Args:
            is_train: Truthy for the training-mode length, falsy otherwise.
        """
        self.is_train = is_train

    def __len__(self):
        """Return the nominal number of items for the configured mode."""
        return self._TRAIN_LENGTH if self.is_train else self._EVAL_LENGTH

    def __getitem__(self, index):
        """Return a new empty dict; ``index`` is ignored."""
        return dict()