File size: 4,607 Bytes
f7a5cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F

# ------------------------------------------------------------------------------------- #

# Module-level placeholder. Nothing in the visible part of this file reads it;
# ``CharacterDataset.__init__`` takes its own ``num_frequencies`` argument,
# which shadows this name inside the constructor.
# NOTE(review): possibly read or overwritten by other modules -- confirm
# before removing.
num_frequencies = None

# ------------------------------------------------------------------------------------- #


class CharacterDataset(Dataset):
    """Map-style dataset of character feature sequences stored as ``.npy`` files.

    Files are looked up under ``dataset_dir`` as ``<sub_dir>/<filename>.npy``:

    * ``<name>/``     -- the 2-D feature array of the sample; its first axis
      is zero-padded to ``num_cams`` entries.
    * ``char_raw/``   -- array whose first row is subtracted from the mesh
      vertices (only read when vertices are loaded).
    * ``vert_raw/``   -- pickled mapping with ``"vertices"`` and ``"faces"``
      entries (only read when ``load_vertices`` is true).
    * ``traj/``       -- stored as ``traj_dir`` but not read by this class.

    ``filenames`` is initialised to ``None``; it must be assigned (a sequence
    of file stems without the ``.npy`` suffix) before the dataset is indexed
    or ``len()`` is called on it.

    Args:
        name: Sub-directory of ``dataset_dir`` holding the feature files.
        dataset_dir: Root directory of the dataset.
        standardize: If true, features are shifted/scaled with the statistics
            passed through ``kwargs["standardization"]``.
        num_feats: Stored on the instance; not used by this class.
        num_cams: Length the first axis of every sample is padded to.
        sequential: If true, ``__getitem__`` returns the feature transposed
            (feature axis first); otherwise it is flattened to 1-D.
        num_frequencies: Stored on the instance; not used by this class.
        min_freq: Stored on the instance; not used by this class.
        max_freq: Stored on the instance; not used by this class.
        load_vertices: If true, mesh vertices and faces are loaded and
            returned alongside the features.
        **kwargs: Must contain ``"standardization"`` when ``standardize`` is
            true: a mapping with ``"norm_mean_h"``, ``"norm_std_h"`` and
            ``"velocity"`` entries.
    """

    # Mesh dimensions of the all-zero placeholder returned when a sample has
    # no vertices.
    # NOTE(review): 6890 vertices / 13776 faces looks like the SMPL body
    # topology -- confirm against the data pipeline.
    _NUM_MESH_VERTICES = 6890
    _NUM_MESH_FACES = 13776

    def __init__(
        self,
        name: str,
        dataset_dir: str,
        standardize: bool,
        num_feats: int,
        num_cams: int,
        sequential: bool,
        num_frequencies: int,
        min_freq: int,
        max_freq: int,
        load_vertices: bool,
        **kwargs,
    ):
        super().__init__()
        self.modality = "char"
        self.name = name
        self.dataset_dir = Path(dataset_dir)
        self.traj_dir = self.dataset_dir / "traj"
        self.data_dir = self.dataset_dir / self.name
        self.vert_dir = self.dataset_dir / "vert_raw"
        self.center_dir = self.dataset_dir / "char_raw"

        # Must be filled in by the caller before the dataset is used.
        self.filenames = None

        self.standardize = standardize
        # ``__getitem__`` reads ``self.velocity`` unconditionally, so it needs
        # a value even when no standardization config is supplied.
        self.velocity = False
        if self.standardize:
            mean_std = kwargs["standardization"]
            # Stored as column vectors of shape (N, 1).
            self.norm_mean = torch.Tensor(mean_std["norm_mean_h"])[:, None]
            self.norm_std = torch.Tensor(mean_std["norm_std_h"])[:, None]
            self.velocity = mean_std["velocity"]

        self.num_cams = num_cams
        self.num_feats = num_feats
        self.sequential = sequential
        self.num_frequencies = num_frequencies
        self.min_freq = min_freq
        self.max_freq = max_freq

        self.load_vertices = load_vertices

    def __len__(self):
        """Return the number of entries in ``self.filenames``."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Load, optionally differentiate/standardize, and pad one sample.

        Args:
            index: Position in ``self.filenames``.

        Returns:
            A tuple ``(char_filename, padded_char_feature, raw_feats)``:

            * ``char_filename`` -- the file stem plus ``".npy"``.
            * ``padded_char_feature`` -- the processed feature padded to
              ``num_cams`` rows, then transposed if ``sequential`` is true
              or flattened to 1-D otherwise.
            * ``raw_feats`` -- dict with ``"char_raw_feat"`` (the unprocessed
              feature, padded and transposed) and, when ``load_vertices`` is
              true, ``"char_vertices"`` and ``"char_faces"``.
        """
        filename = self.filenames[index]
        char_filename = filename + ".npy"

        raw_char_feature = torch.from_numpy(
            np.load(self.data_dir / char_filename)
        ).to(torch.float32)

        # NOTE(review): a sample with more than ``num_cams`` rows yields a
        # negative pad size here -- confirm how that case should be handled.
        padding_size = self.num_cams - raw_char_feature.shape[0]
        padded_raw_char_feature = F.pad(
            raw_char_feature, (0, 0, 0, padding_size)
        ).permute(1, 0)

        if self.velocity:
            # Keep the first row as-is and replace every following row by its
            # difference to the previous one.
            velocity = raw_char_feature[1:] - raw_char_feature[:-1]
            char_feature = torch.cat([raw_char_feature[0][None], velocity])
        else:
            # Copy so that in-place standardization cannot touch the raw data.
            char_feature = raw_char_feature.clone()

        if self.standardize:
            self._standardize_in_place(char_feature)

        padded_char_feature = F.pad(
            char_feature,
            (0, 0, 0, self.num_cams - char_feature.shape[0]),
        )
        if self.sequential:
            padded_char_feature = padded_char_feature.permute(1, 0)
        else:
            padded_char_feature = padded_char_feature.reshape(-1)

        raw_feats = {"char_raw_feat": padded_raw_char_feature}
        if self.load_vertices:
            padded_verts, faces = self._load_mesh(char_filename, padding_size)
            raw_feats["char_vertices"] = padded_verts
            raw_feats["char_faces"] = faces

        return char_filename, padded_char_feature, raw_feats

    def _standardize_in_place(self, char_feature):
        """Shift and scale ``char_feature`` in place with the stored statistics."""
        device = char_feature.device
        mean = self.norm_mean[:, 0].to(device)
        std = self.norm_std[:, 0].to(device)

        if len(self.norm_mean) == 6:
            # Normalize the first frame (origin) and the rest (velocity)
            # separately: entries 0-2 apply to row 0, entries 3-5 to the rest.
            char_feature[0] -= mean[:3]
            char_feature[0] /= std[:3]
            char_feature[1:] -= mean[3:]
            char_feature[1:] /= std[3:]
        else:
            # Normalize all rows with the same statistics.
            char_feature -= mean
            char_feature /= std

    def _load_mesh(self, char_filename, padding_size):
        """Return ``(padded_verts, faces)`` for one sample.

        When the stored ``"vertices"`` entry is ``None`` an all-zero
        placeholder mesh is returned instead.
        """
        # SECURITY: ``allow_pickle=True`` unpickles the file content, which can
        # execute arbitrary code -- only load files from a trusted source.
        # ``[()]`` unwraps the 0-d object array into the stored mapping.
        raw_verts = np.load(self.vert_dir / char_filename, allow_pickle=True)[()]

        if raw_verts["vertices"] is None:
            padded_verts = torch.zeros(
                (self.num_cams, self._NUM_MESH_VERTICES, 3), dtype=torch.float32
            )
            faces = torch.zeros((self._NUM_MESH_FACES, 3), dtype=torch.int16)
            return padded_verts, faces

        # Center to offset mesh: first row of the ``char_raw`` array.
        center_offset = torch.from_numpy(
            np.load(self.center_dir / char_filename)[0]
        ).to(torch.float32)

        verts = torch.from_numpy(raw_verts["vertices"]).to(torch.float32)
        verts -= center_offset
        # Pad only the leading axis so it lines up with the padded features.
        padded_verts = F.pad(verts, (0, 0, 0, 0, 0, padding_size))
        faces = torch.from_numpy(raw_verts["faces"]).to(torch.int16)
        return padded_verts, faces