Spaces:
Sleeping
Sleeping
File size: 2,050 Bytes
f9a674e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
from .catalog import DatasetCatalog
from ldm.util import instantiate_from_config
import torch
class ConCatDataset():
    """Expose several catalog datasets as one flat, index-addressable dataset.

    Each dataset named in ``dataset_name_list`` is looked up on a
    ``DatasetCatalog``, instantiated through ``instantiate_from_config`` and
    appended to ``self.datasets``. A dataset can be repeated several times in
    the flat index space; repeats share the same underlying dataset object.

    Attributes set by ``__init__``:
        datasets: list of instantiated datasets, in iteration order of
            ``dataset_name_list``.
        total_length: sum over datasets of ``len(dataset) * repeat``.
        which_dataset: int64 tensor of length ``total_length``; entry ``i`` is
            the position in ``datasets`` that serves global index ``i``.
        mapping: int64 tensor of length ``total_length``; entry ``i`` is the
            local index inside that dataset.
    """

    def __init__(self, dataset_name_list, ROOT, which_embedder, train=True, repeats=None):
        """Instantiate the requested datasets and build the index tables.

        Args:
            dataset_name_list: despite the name, a mapping (``.items()`` is
                called on it) from catalog attribute name to a dict of
                parameter overrides, or ``None`` for no overrides.
            ROOT: forwarded unchanged to ``DatasetCatalog``.
            which_embedder: forwarded unchanged to ``DatasetCatalog``.
            train: if true use each catalog entry's ``'train_params'``,
                otherwise its ``'val_params'``.
            repeats: optional sequence with one repeat count per dataset, in
                the same order as ``dataset_name_list``. Defaults to one pass
                over every dataset.

        Raises:
            AssertionError: if ``repeats`` is given and its length differs
                from the number of datasets.
        """
        if repeats is None:
            repeats = [1] * len(dataset_name_list)
        else:
            assert len(repeats) == len(dataset_name_list), \
                "repeats must contain exactly one entry per dataset"

        catalog = DatasetCatalog(ROOT, which_embedder)

        self.datasets = []
        offset_chunks = []   # per repeat: global start index of that segment
        owner_chunks = []    # per repeat: index into self.datasets
        running_length = 0

        for dataset_idx, (dataset_name, yaml_params) in enumerate(dataset_name_list.items()):
            repeat = repeats[dataset_idx]
            dataset_dict = getattr(catalog, dataset_name)
            target = dataset_dict['target']
            params = dataset_dict['train_params'] if train else dataset_dict['val_params']
            if yaml_params is not None:
                # Apply overrides on a copy so the catalog's own parameter
                # dict is not modified as a side effect.
                params = dict(params)
                params.update(yaml_params)

            dataset = instantiate_from_config(dict(target=target, params=params))
            self.datasets.append(dataset)

            size = len(dataset)
            for _ in range(repeat):
                # Build the tables directly as int64. Going through a default
                # (float32) tensor and casting afterwards rounds any value
                # above 2**24, which corrupts the mapping for large corpora.
                offset_chunks.append(torch.full((size,), running_length, dtype=torch.long))
                owner_chunks.append(torch.full((size,), dataset_idx, dtype=torch.long))
                running_length += size

        if offset_chunks:
            offset_map = torch.cat(offset_chunks, dim=0)
            self.which_dataset = torch.cat(owner_chunks, dim=0)
        else:
            # No datasets (or only zero repeats): torch.cat rejects an empty
            # list, so represent the empty concatenation explicitly.
            offset_map = torch.empty(0, dtype=torch.long)
            self.which_dataset = torch.empty(0, dtype=torch.long)

        self.total_length = running_length
        # Global index minus the start of its segment gives the local index.
        self.mapping = torch.arange(self.total_length) - offset_map

    def total_images(self):
        """Return the sum of ``total_images()`` over the underlying datasets.

        Each dataset's individual count is printed as it is visited. Repeats
        are not counted: every dataset in ``self.datasets`` contributes once.
        """
        count = 0
        for dataset in self.datasets:
            # Query once and reuse the value for both the print and the sum.
            n_images = dataset.total_images()
            print(n_images)
            count += n_images
        return count

    def __getitem__(self, idx):
        """Return the sample at global index ``idx`` from its owning dataset."""
        # Convert the 0-d tensor entries to plain ints so the child dataset
        # receives an ordinary integer index rather than a tensor.
        dataset = self.datasets[int(self.which_dataset[idx])]
        return dataset[int(self.mapping[idx])]

    def __len__(self):
        """Return the number of addressable samples, repeats included."""
        return self.total_length
|