Spaces:
Runtime error
Runtime error
import pytorch_lightning as pl | |
from torch.utils.data import Dataset | |
import webdataset as wds | |
from torch.utils.data.distributed import DistributedSampler | |
class DummyDataset(pl.LightningDataModule): | |
def __init__(self,seed): | |
super().__init__() | |
def setup(self, stage): | |
if stage in ['fit']: | |
self.train_dataset = DummyData(True) | |
self.val_dataset = DummyData(False) | |
else: | |
raise NotImplementedError | |
def train_dataloader(self): | |
return wds.WebLoader(self.train_dataset, batch_size=1, num_workers=0, shuffle=False) | |
def val_dataloader(self): | |
return wds.WebLoader(self.val_dataset, batch_size=1, num_workers=0, shuffle=False) | |
def test_dataloader(self): | |
return wds.WebLoader(DummyData(False)) | |
class DummyData(Dataset): | |
def __init__(self,is_train): | |
self.is_train=is_train | |
def __len__(self): | |
if self.is_train: | |
return 99999999 | |
else: | |
return 1 | |
def __getitem__(self, index): | |
return {} | |