""" Compose multiple datasets in a single loader. """
import numpy as np
from copy import deepcopy
from torch.utils.data import Dataset
from .wireframe_dataset import WireframeDataset
from .holicity_dataset import HolicityDataset


class MergeDataset(Dataset):
    def __init__(self, mode, config=None):
        super(MergeDataset, self).__init__()
        # Initialize one dataset per entry in config["datasets"]
        self._datasets = []
        for i, d in enumerate(config["datasets"]):
            # Copy the config for each dataset so that per-dataset overrides
            # (e.g. "train_split" below) do not leak between datasets
            spec_config = deepcopy(config)
            spec_config["dataset_name"] = d
            spec_config["gt_source_train"] = config["gt_source_train"][i]
            spec_config["gt_source_test"] = config["gt_source_test"][i]
            if d == "wireframe":
                self._datasets.append(WireframeDataset(mode, spec_config))
            elif d == "holicity":
                spec_config["train_split"] = config["train_splits"][i]
                self._datasets.append(HolicityDataset(mode, spec_config))
            else:
                raise ValueError("Unknown dataset: " + d)

        # Relative sampling probability of each dataset
        self._weights = config["weights"]

    def __getitem__(self, item):
        # The requested index is ignored: pick a dataset according to the
        # sampling weights, then return one of its samples at random
        dataset = self._datasets[
            np.random.choice(range(len(self._datasets)), p=self._weights)
        ]
        return dataset[np.random.randint(len(dataset))]

    def __len__(self):
        return np.sum([len(d) for d in self._datasets])
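

# --- Usage sketch (not part of the original module) --------------------------
# A minimal, hedged example of how MergeDataset might be wired up. The config
# keys mirror the ones read in __init__ above; the mode string, file names,
# split name, and weights are hypothetical placeholders, and the underlying
# WireframeDataset / HolicityDataset classes likely expect additional keys
# (image size, preprocessing, ...) that are omitted here.
if __name__ == "__main__":
    example_config = {
        "datasets": ["wireframe", "holicity"],
        "gt_source_train": ["wireframe_train_gt.h5", "holicity_train_gt.h5"],
        "gt_source_test": ["wireframe_test_gt.h5", "holicity_test_gt.h5"],
        # Only read for the holicity entry; the wireframe slot is a filler
        "train_splits": [None, "2018-01"],
        # Sampling probabilities passed to np.random.choice; must sum to 1
        "weights": [0.7, 0.3],
    }
    merged = MergeDataset("train", example_config)
    print("Total number of samples:", len(merged))
    sample = merged[0]  # index is ignored: dataset picked by weight, sample drawn at random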