File size: 3,011 Bytes
f7a5cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from copy import deepcopy as dp
from pathlib import Path

from torch.utils.data import Dataset


class MultimodalDataset(Dataset):
    """Dataset that pairs a trajectory dataset with auxiliary modality datasets.

    Each item combines one sample from ``trajectory`` with the same-index
    sample from every dataset passed via ``**modalities``; filenames are
    asserted to match (stem-wise) so the streams stay aligned.

    Parameters
    ----------
    name : str
        Human-readable name for this dataset instance.
    dataset_name : str
        Identifier of the underlying dataset.
    dataset_dir : str or Path
        Root directory of the dataset on disk.
    trajectory : Dataset
        Trajectory dataset; must expose ``set_split``, ``set_augmentation``,
        ``filenames``, ``get_feature`` and ``get_matrix``.
    feature_type, num_rawfeats, num_feats, num_cams : misc
        Stored as-is for downstream consumers.
    num_cond_feats : misc
        Accepted for interface compatibility but not stored — presumably
        consumed by the caller/config; verify before removing.
    standardization : misc
        Standardization parameters, stored as-is.
    augmentation : config-like, optional
        When given, must provide ``.rate``, ``.trajectory`` and optionally a
        mapping ``.modalities`` of per-modality augmentation configs.
    **modalities : Dataset
        Additional modality datasets keyed by modality name.
    """

    def __init__(
        self,
        name,
        dataset_name,
        dataset_dir,
        trajectory,
        feature_type,
        num_rawfeats,
        num_feats,
        num_cams,
        num_cond_feats,
        standardization,
        augmentation=None,
        **modalities,
    ):
        self.dataset_dir = Path(dataset_dir)
        self.name = name
        self.dataset_name = dataset_name
        self.feature_type = feature_type
        self.num_rawfeats = num_rawfeats
        self.num_feats = num_feats
        self.num_cams = num_cams
        self.trajectory_dataset = trajectory
        self.standardization = standardization
        self.modality_datasets = modalities

        if augmentation is not None:
            self.augmentation = True
            self.augmentation_rate = augmentation.rate
            self.trajectory_dataset.set_augmentation(augmentation.trajectory)
            if hasattr(augmentation, "modalities"):
                # FIX: iterate key/value pairs explicitly. Iterating a
                # dict-like config directly yields only the keys, so the
                # original 2-tuple unpacking would fail (ValueError) for
                # any real modality mapping.
                for modality, augments in augmentation.modalities.items():
                    self.modality_datasets[modality].set_augmentation(augments)
        else:
            self.augmentation = False

    # --------------------------------------------------------------------------------- #

    def set_split(self, split: str, train_rate: float = 1.0):
        """Select the train/val/test split and propagate it to all modalities.

        Deep-copies the trajectory dataset so the original stays reusable,
        then shares the resulting filename list with every modality dataset
        to keep indices aligned. Returns ``self`` for chaining.
        """
        self.split = split

        # Get trajectory split (set_split is assumed to return its instance).
        self.trajectory_dataset = dp(self.trajectory_dataset).set_split(
            split, train_rate
        )
        self.root_filenames = self.trajectory_dataset.filenames

        # Get modality split: reuse the trajectory filenames so every
        # modality iterates the exact same samples in the same order.
        for modality_name in self.modality_datasets.keys():
            self.modality_datasets[modality_name].filenames = self.root_filenames

        # Expose trajectory helpers directly on this dataset.
        self.get_feature = self.trajectory_dataset.get_feature
        self.get_matrix = self.trajectory_dataset.get_matrix

        return self

    # --------------------------------------------------------------------------------- #

    def __getitem__(self, index):
        """Return a dict merging the trajectory sample with every modality sample.

        Keys: ``traj_filename``, ``traj_feat``, ``padding_mask``,
        ``intrinsics``, plus ``{modality}_filename``, ``{modality}_feat``,
        ``{modality}_raw`` and ``{modality}_padding_mask`` per modality.
        """
        traj_out = self.trajectory_dataset[index]
        traj_filename, traj_feature, padding_mask, intrinsics = traj_out

        out = {
            "traj_filename": traj_filename,
            "traj_feat": traj_feature,
            "padding_mask": padding_mask,
            "intrinsics": intrinsics,
        }

        for modality_name, modality_dataset in self.modality_datasets.items():
            modality_filename, modality_feature, modality_raw = modality_dataset[index]
            # Guard against misaligned streams: stems must match.
            assert traj_filename.split(".")[0] == modality_filename.split(".")[0]
            out[f"{modality_name}_filename"] = modality_filename
            out[f"{modality_name}_feat"] = modality_feature
            out[f"{modality_name}_raw"] = modality_raw
            # Modalities share the trajectory's padding mask.
            out[f"{modality_name}_padding_mask"] = padding_mask

        return out

    def __len__(self):
        """Length is defined by the trajectory dataset."""
        return len(self.trajectory_dataset)