Spaces:
Running
Running
""" | |
The interface of initializing different datasets. | |
""" | |
from .synthetic_dataset import SyntheticShapes | |
from .wireframe_dataset import WireframeDataset | |
from .holicity_dataset import HolicityDataset | |
from .merge_dataset import MergeDataset | |
def get_dataset(mode="train", dataset_cfg=None):
    """Initialize a dataset and its matching collate function from a config.

    Args:
        mode: Value forwarded unchanged as the first argument of the selected
            dataset constructor (defaults to "train").
        dataset_cfg: Dataset configuration. Its "dataset_name" entry selects
            the dataset class; the whole config is forwarded to the dataset
            constructor as its second argument.

    Returns:
        A ``(dataset, collate_fn)`` tuple.

    Raises:
        ValueError: If ``dataset_cfg`` is None, lacks a "dataset_name" entry,
            or names an unsupported dataset.
    """
    # Check dataset config is given
    if dataset_cfg is None:
        raise ValueError("[Error] The dataset config is required!")

    # Read the dataset name once; report a missing key with the same
    # exception type as the other config errors instead of a bare KeyError.
    try:
        dataset_name = dataset_cfg["dataset_name"]
    except KeyError as err:
        raise ValueError(
            "[Error] The dataset config must define 'dataset_name'!") from err

    # Synthetic dataset
    if dataset_name == "synthetic_shape":
        dataset = SyntheticShapes(mode, dataset_cfg)
        # Collate functions are imported lazily, only for the selected dataset
        from .synthetic_dataset import synthetic_collate_fn
        collate_fn = synthetic_collate_fn
    # Wireframe dataset
    elif dataset_name == "wireframe":
        dataset = WireframeDataset(mode, dataset_cfg)
        from .wireframe_dataset import wireframe_collate_fn
        collate_fn = wireframe_collate_fn
    # Holicity dataset
    elif dataset_name == "holicity":
        dataset = HolicityDataset(mode, dataset_cfg)
        from .holicity_dataset import holicity_collate_fn
        collate_fn = holicity_collate_fn
    # Dataset merging several datasets in one
    elif dataset_name == "merge":
        dataset = MergeDataset(mode, dataset_cfg)
        # NOTE(review): the merged dataset reuses the Holicity collate
        # function; presumably all merged sources emit the same sample
        # format -- confirm against MergeDataset.
        from .holicity_dataset import holicity_collate_fn
        collate_fn = holicity_collate_fn
    else:
        # Wrap in a 1-tuple so non-string names (e.g. tuples) format safely
        raise ValueError(
            "[Error] The dataset '%s' is not supported" % (dataset_name,))
    return dataset, collate_fn