Spaces:
Paused
Paused
from typing import Optional | |
import torchdata.datapipes.iter | |
import webdataset as wds | |
from omegaconf import DictConfig | |
from pytorch_lightning import LightningDataModule | |
try:
    from sdata import create_dataset, create_dummy_dataset, create_loader
except ImportError:
    # The dataset helpers come from the optional ``stable-datasets`` package;
    # without it this module cannot work, so explain how to install it and stop.
    import sys

    print("#" * 100, file=sys.stderr)
    print("Datasets not yet available", file=sys.stderr)
    print("to enable, we need to add stable-datasets as a submodule", file=sys.stderr)
    print("please use ``git submodule update --init --recursive``", file=sys.stderr)
    print(
        "and do ``pip install -e stable-datasets/`` from the root of this repo",
        file=sys.stderr,
    )
    print("#" * 100, file=sys.stderr)
    # sys.exit instead of the site-provided ``exit`` builtin: the latter is
    # intended for the interactive interpreter and is absent under ``python -S``.
    sys.exit(1)
class StableDataModuleFromConfig(LightningDataModule):
    """Lightning data module whose train/val/test pipelines come from configs.

    Every split config must contain two entries:

    * ``datapipeline`` -- keyword arguments passed to ``create_dataset``
      (or ``create_dummy_dataset`` when ``dummy=True``) in ``setup()``;
    * ``loader`` -- keyword arguments passed to ``create_loader`` in the
      ``*_dataloader`` hooks.
    """

    def __init__(
        self,
        train: DictConfig,
        validation: Optional[DictConfig] = None,
        test: Optional[DictConfig] = None,
        skip_val_loader: bool = False,
        dummy: bool = False,
    ):
        """Store and validate the per-split configs.

        Args:
            train: config for the training split (required).
            validation: config for the validation split. When omitted and
                ``skip_val_loader`` is False, the training config is reused.
            test: optional config for the test split.
            skip_val_loader: when True and ``validation`` is None, do not fall
                back to the training config; no validation pipeline is built.
            dummy: build pipelines with ``create_dummy_dataset`` instead of
                ``create_dataset``.

        Raises:
            ValueError: if a provided config lacks ``datapipeline`` or ``loader``.
        """
        super().__init__()
        self.train_config = train
        self._check_config(self.train_config, "train")

        self.val_config = validation
        if self.val_config is not None:
            # Validate any explicitly given validation config: setup() builds
            # it regardless of skip_val_loader.
            self._check_config(self.val_config, "validation")
        elif not skip_val_loader:
            print(
                "Warning: No Validation datapipeline defined, using that one from training"
            )
            self.val_config = train

        self.test_config = test
        if self.test_config is not None:
            self._check_config(self.test_config, "test")

        # Filled in by setup(); None means "not built (yet)" so the dataloader
        # hooks can fail with a clear message instead of an AttributeError.
        self.train_datapipeline = None
        self.val_datapipeline = None
        self.test_datapipeline = None

        self.dummy = dummy
        if self.dummy:
            print("#" * 100)
            print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
            print("#" * 100)

    @staticmethod
    def _check_config(config: DictConfig, name: str) -> None:
        """Raise ValueError unless *config* has ``datapipeline`` and ``loader``."""
        # Explicit raise rather than assert so the check survives ``python -O``.
        if "datapipeline" not in config or "loader" not in config:
            raise ValueError(
                f"{name} config requires the fields `datapipeline` and `loader`"
            )

    def setup(self, stage: str) -> None:
        """Build the datapipeline for every configured split.

        ``stage`` is accepted to match Lightning's hook signature but is not
        used: all configured splits are (re)built on every call.
        """
        print("Preparing datasets")
        data_fn = create_dummy_dataset if self.dummy else create_dataset
        self.train_datapipeline = data_fn(**self.train_config["datapipeline"])
        if self.val_config is not None:
            self.val_datapipeline = data_fn(**self.val_config["datapipeline"])
        if self.test_config is not None:
            self.test_datapipeline = data_fn(**self.test_config["datapipeline"])

    @staticmethod
    def _make_loader(datapipeline, config: Optional[DictConfig], name: str):
        """Wrap *datapipeline* with ``create_loader`` using ``config['loader']``.

        Raises:
            RuntimeError: if no datapipeline was built for this split (no
                config was given for it, or ``setup()`` has not run yet).
        """
        if datapipeline is None or config is None:
            raise RuntimeError(
                f"No {name} datapipeline available: either no {name} config "
                "was given or setup() has not been called"
            )
        return create_loader(datapipeline, **config["loader"])

    # NOTE(review): all three hooks return create_loader(...), yet the return
    # annotations differ between train and val/test -- confirm the real type.
    def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
        """Return the training loader built from ``train_config['loader']``."""
        return self._make_loader(self.train_datapipeline, self.train_config, "train")

    def val_dataloader(self) -> wds.DataPipeline:
        """Return the validation loader built from ``val_config['loader']``."""
        return self._make_loader(self.val_datapipeline, self.val_config, "validation")

    def test_dataloader(self) -> wds.DataPipeline:
        """Return the test loader built from ``test_config['loader']``."""
        return self._make_loader(self.test_datapipeline, self.test_config, "test")