from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F

# ------------------------------------------------------------------------------------- #

num_frequencies = None

# ------------------------------------------------------------------------------------- #


class CharacterDataset(Dataset):
    """Dataset of per-sample character features, optionally with mesh vertices.

    Sample ``<stem>`` is read from ``<dataset_dir>/<name>/<stem>.npy`` as a 2-D
    array. Its first axis is zero-padded (or, when longer, cropped by
    ``F.pad`` with a negative pad) to ``num_cams`` rows.

    ``self.filenames`` is initialised to ``None`` and must be set to a sequence
    of file stems (e.g. by a subclass or the caller) before the dataset is
    indexed or measured.

    Args:
        name: Sub-directory of ``dataset_dir`` that holds the feature files.
        dataset_dir: Root directory of the dataset.
        standardize: If True, normalize features with the statistics in
            ``kwargs["standardization"]`` (keys ``"norm_mean_h"``,
            ``"norm_std_h"`` and ``"velocity"``).
        num_feats: Stored as ``self.num_feats``; not used by this class.
        num_cams: Number of rows every sample is padded/cropped to.
        sequential: If True, features are returned as
            ``(feature_dim, num_cams)``; otherwise they are flattened to 1-D.
        num_frequencies: Stored; not used by this class.
        min_freq: Stored; not used by this class.
        max_freq: Stored; not used by this class.
        load_vertices: If True, also load vertices/faces from ``vert_raw``.
        **kwargs: Must contain ``"standardization"`` when ``standardize`` is True.
    """

    # Placeholder mesh size used when a sample has no vertices. Presumably the
    # SMPL body-model topology (6890 vertices, 13776 faces) — confirm upstream.
    NUM_VERTICES = 6890
    NUM_FACES = 13776

    def __init__(
        self,
        name: str,
        dataset_dir: str,
        standardize: bool,
        num_feats: int,
        num_cams: int,
        sequential: bool,
        num_frequencies: int,
        min_freq: int,
        max_freq: int,
        load_vertices: bool,
        **kwargs,
    ):
        super().__init__()
        self.modality = "char"
        self.name = name

        self.dataset_dir = Path(dataset_dir)
        self.traj_dir = self.dataset_dir / "traj"
        self.data_dir = self.dataset_dir / self.name
        self.vert_dir = self.dataset_dir / "vert_raw"
        self.center_dir = self.dataset_dir / "char_raw"
        self.filenames = None

        self.standardize = standardize
        # Velocity encoding is only configured through the standardization
        # statistics. Default to absolute features so __getitem__ never reads
        # a missing attribute when standardization is disabled.
        self.velocity = False
        if self.standardize:
            mean_std = kwargs["standardization"]
            # Stored as column vectors of shape (k, 1).
            self.norm_mean = torch.Tensor(mean_std["norm_mean_h"])[:, None]
            self.norm_std = torch.Tensor(mean_std["norm_std_h"])[:, None]
            self.velocity = mean_std["velocity"]

        self.num_cams = num_cams
        self.num_feats = num_feats
        self.sequential = sequential
        self.num_frequencies = num_frequencies
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.load_vertices = load_vertices

    def __len__(self) -> int:
        """Return the number of samples listed in ``self.filenames``."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Load, encode and pad one sample.

        Returns:
            A tuple ``(char_filename, padded_char_feature, raw_feats)``:

            - ``char_filename``: the sample stem with ``".npy"`` appended.
            - ``padded_char_feature``: the encoded features, padded to
              ``num_cams`` rows, then transposed when ``sequential`` is True
              or flattened otherwise.
            - ``raw_feats``: a dict holding ``"char_raw_feat"`` (un-encoded
              features padded and transposed to ``(feature_dim, num_cams)``),
              plus ``"char_vertices"`` and ``"char_faces"`` when
              ``load_vertices`` is True.
        """
        char_filename = self.filenames[index] + ".npy"

        raw_char_feature = torch.from_numpy(
            np.load(self.data_dir / char_filename)
        ).to(torch.float32)
        # A positive value appends zero rows; a negative one makes F.pad crop.
        padding_size = self.num_cams - raw_char_feature.shape[0]
        padded_raw_char_feature = F.pad(
            raw_char_feature, (0, 0, 0, padding_size)
        ).permute(1, 0)

        char_feature = self._encode_features(raw_char_feature)
        padded_char_feature = F.pad(
            char_feature, (0, 0, 0, self.num_cams - char_feature.shape[0])
        )
        if self.sequential:
            padded_char_feature = padded_char_feature.permute(1, 0)
        else:
            padded_char_feature = padded_char_feature.reshape(-1)

        raw_feats = {"char_raw_feat": padded_raw_char_feature}
        if self.load_vertices:
            padded_verts, faces = self._load_mesh(char_filename, padding_size)
            raw_feats["char_vertices"] = padded_verts
            raw_feats["char_faces"] = faces

        return char_filename, padded_char_feature, raw_feats

    def _encode_features(self, raw_char_feature):
        """Return a velocity-encoded and/or standardized copy of the features.

        The input tensor is never modified in place.
        """
        if self.velocity:
            # Keep the first frame absolute; replace the rest by frame deltas.
            deltas = raw_char_feature[1:] - raw_char_feature[:-1]
            char_feature = torch.cat([raw_char_feature[0][None], deltas])
        else:
            # Clone so the in-place normalization below leaves the raw tensor intact.
            char_feature = raw_char_feature.clone()

        if self.standardize:
            device = char_feature.device
            mean = self.norm_mean[:, 0].to(device)
            std = self.norm_std[:, 0].to(device)
            if len(mean) == 6:
                # The first frame (origin) and the remaining frames (velocity)
                # have separate statistics: entries 0-2 and 3-5 respectively.
                char_feature[0] -= mean[:3]
                char_feature[0] /= std[:3]
                char_feature[1:] -= mean[3:]
                char_feature[1:] /= std[3:]
            else:
                # A single set of statistics applies to every frame.
                char_feature -= mean
                char_feature /= std
        return char_feature

    def _load_mesh(self, char_filename, padding_size):
        """Load centred, padded vertices and faces for one sample.

        If the stored vertices are ``None``, all-zero placeholders of the
        default mesh size are returned instead.
        """
        # NOTE(review): allow_pickle=True unpickles the file contents. Only use
        # this with trusted dataset files.
        # [()] extracts the Python object stored in a 0-d object array.
        raw_verts = np.load(self.vert_dir / char_filename, allow_pickle=True)[()]
        if raw_verts["vertices"] is None:
            padded_verts = torch.zeros(
                (self.num_cams, self.NUM_VERTICES, 3), dtype=torch.float32
            )
            faces = torch.zeros((self.NUM_FACES, 3), dtype=torch.int16)
            return padded_verts, faces

        # The center offset is only needed here, so it is loaded lazily rather
        # than once per sample regardless of use.
        center_offset = torch.from_numpy(
            np.load(self.center_dir / char_filename)[0]
        ).to(torch.float32)
        verts = torch.from_numpy(raw_verts["vertices"]).to(torch.float32)
        verts -= center_offset
        # Pad (or crop) the leading frame axis of the 3-D vertex tensor.
        padded_verts = F.pad(verts, (0, 0, 0, 0, 0, padding_size))
        faces = torch.from_numpy(raw_verts["faces"]).to(torch.int16)
        return padded_verts, faces