|
""" |
|
Interface for initializing the different datasets.
|
""" |
|
from .synthetic_dataset import SyntheticShapes |
|
from .wireframe_dataset import WireframeDataset |
|
from .holicity_dataset import HolicityDataset |
|
from .merge_dataset import MergeDataset |
|
|
|
|
|
def get_dataset(mode="train", dataset_cfg=None):
    """Create the dataset named in ``dataset_cfg`` along with its collate function.

    Args:
        mode: Forwarded unchanged as the first argument of the dataset
            constructor.
        dataset_cfg: Configuration mapping. Its ``"dataset_name"`` entry
            selects the dataset ("synthetic_shape", "wireframe", "holicity"
            or "merge"). The whole mapping is forwarded to the constructor.

    Returns:
        tuple: ``(dataset, collate_fn)``.

    Raises:
        ValueError: If ``dataset_cfg`` is None, or if its ``"dataset_name"``
            is not one of the supported names.
        KeyError: If ``dataset_cfg`` has no ``"dataset_name"`` entry.
    """
    # A configuration is mandatory; fail fast without one.
    if dataset_cfg is None:
        raise ValueError("[Error] The dataset config is required!")

    name = dataset_cfg["dataset_name"]

    # Every branch builds the dataset first, then pulls the collate
    # function from the corresponding dataset module.
    if name == "synthetic_shape":
        dataset = SyntheticShapes(mode, dataset_cfg)
        from .synthetic_dataset import synthetic_collate_fn
        return dataset, synthetic_collate_fn

    if name == "wireframe":
        dataset = WireframeDataset(mode, dataset_cfg)
        from .wireframe_dataset import wireframe_collate_fn
        return dataset, wireframe_collate_fn

    if name == "holicity":
        dataset = HolicityDataset(mode, dataset_cfg)
        from .holicity_dataset import holicity_collate_fn
        return dataset, holicity_collate_fn

    if name == "merge":
        dataset = MergeDataset(mode, dataset_cfg)
        # The merged dataset reuses the Holicity collate function.
        # NOTE(review): presumably MergeDataset yields Holicity-compatible
        # samples — confirm against merge_dataset.py.
        from .holicity_dataset import holicity_collate_fn
        return dataset, holicity_collate_fn

    # No branch matched: the requested dataset name is unknown.
    raise ValueError(
        "[Error] The dataset '%s' is not supported" % name
    )
|
|