File size: 1,387 Bytes
404d2af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b973ee
 
 
 
404d2af
 
 
8b973ee
404d2af
 
8b973ee
 
 
404d2af
 
8b973ee
 
 
404d2af
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
""" Compose multiple datasets in a single loader. """

import numpy as np
from copy import deepcopy
from torch.utils.data import Dataset

from .wireframe_dataset import WireframeDataset
from .holicity_dataset import HolicityDataset


class MergeDataset(Dataset):
    """Dataset that randomly samples from several underlying datasets.

    Each ``__getitem__`` call first picks one sub-dataset with probability
    given by ``config["weights"]`` and then returns a uniformly random
    sample from it; the index passed to ``__getitem__`` is ignored.

    Args:
        mode: Forwarded unchanged to every sub-dataset constructor.
        config: Dict with (at least) the keys
            - ``"datasets"``: list of dataset names; supported values are
              ``"wireframe"`` and ``"holicity"``.
            - ``"gt_source_train"`` / ``"gt_source_test"``: lists indexed in
              the same order as ``"datasets"``.
            - ``"train_splits"``: list indexed like ``"datasets"``; only read
              for ``"holicity"`` entries.
            - ``"weights"``: per-dataset sampling probabilities, passed as
              ``p`` to ``np.random.choice`` (so they must sum to 1).
            Every key is deep-copied into each sub-dataset's own config, with
            ``"dataset_name"``, ``"gt_source_train"``, ``"gt_source_test"``
            (and ``"train_split"`` for holicity) overridden per dataset.

    Raises:
        ValueError: if ``config`` is None, names an unknown dataset, or its
            ``"weights"`` length differs from the number of datasets.
    """

    def __init__(self, mode, config=None):
        super(MergeDataset, self).__init__()
        if config is None:
            raise ValueError("MergeDataset requires a config dict.")

        # Initialize the datasets
        self._datasets = []
        for i, name in enumerate(config["datasets"]):
            # Fresh copy per sub-dataset: reusing one dict across iterations
            # would let later iterations overwrite the settings of
            # sub-datasets that were built earlier and kept a reference.
            spec_config = deepcopy(config)
            spec_config["dataset_name"] = name
            spec_config["gt_source_train"] = config["gt_source_train"][i]
            spec_config["gt_source_test"] = config["gt_source_test"][i]
            if name == "wireframe":
                self._datasets.append(WireframeDataset(mode, spec_config))
            elif name == "holicity":
                spec_config["train_split"] = config["train_splits"][i]
                self._datasets.append(HolicityDataset(mode, spec_config))
            else:
                raise ValueError(f"Unknown dataset: {name}")

        self._weights = config["weights"]
        # Fail at construction instead of on the first __getitem__ call.
        if len(self._weights) != len(self._datasets):
            raise ValueError(
                f"Expected {len(self._datasets)} weights (one per dataset), "
                f"got {len(self._weights)}."
            )

    def __getitem__(self, item):
        # `item` is intentionally unused: a sub-dataset is chosen according
        # to the configured weights, then a random sample is drawn from it.
        dataset = self._datasets[
            np.random.choice(len(self._datasets), p=self._weights)
        ]
        return dataset[np.random.randint(len(dataset))]

    def __len__(self):
        # Builtin sum yields a Python int (0 when there are no datasets);
        # np.sum([]) would return the float 0.0, which len() rejects.
        return sum(len(d) for d in self._datasets)