# Spaces: Running on Zero
from pathlib import Path | |
import numpy as np | |
import torch | |
from torch.utils.data import Dataset | |
import torch.nn.functional as F | |
# ------------------------------------------------------------------------------------- #
# Module-level ``num_frequencies`` placeholder, initialised to None. The dataset
# class below stores its own per-instance value from its constructor argument.
num_frequencies = None
# ------------------------------------------------------------------------------------- #
class CharacterDataset(Dataset):
    """Per-sample character feature sequences, optionally with mesh vertices.

    Items are addressed through ``self.filenames`` (file stems without the
    ``.npy`` suffix), which is left as ``None`` here and must be assigned
    before the dataset is indexed. For a stem ``f`` the dataset reads:

    * ``<dataset_dir>/<name>/f.npy``: the 2-D character feature array with
      frames along axis 0;
    * ``<dataset_dir>/vert_raw/f.npy``: a pickled dict with ``"vertices"``
      and ``"faces"`` entries (only when ``load_vertices`` is True);
    * ``<dataset_dir>/char_raw/f.npy``: its first row is subtracted from the
      vertices (only read when vertices are present).

    The frame axis of every returned sequence is zero-padded to ``num_cams``
    entries. NOTE(review): assumes no sample has more than ``num_cams``
    frames; otherwise ``F.pad`` receives a negative pad amount.
    """

    # Placeholder mesh size used when a sample has no vertices.
    # NOTE(review): 6890 vertices / 13776 faces matches the SMPL body-model
    # topology; confirm this is the intended mesh.
    _NUM_VERTICES = 6890
    _NUM_FACES = 13776

    def __init__(
        self,
        name: str,
        dataset_dir: str,
        standardize: bool,
        num_feats: int,
        num_cams: int,
        sequential: bool,
        num_frequencies: int,
        min_freq: int,
        max_freq: int,
        load_vertices: bool,
        **kwargs,
    ):
        """Configure paths and preprocessing options.

        Args:
            name: Sub-directory of ``dataset_dir`` holding the feature files.
            dataset_dir: Dataset root directory.
            standardize: Normalize features with the statistics passed in
                ``kwargs["standardization"]``.
            num_feats: Stored on the instance; not read by this class's own
                methods.
            num_cams: Target length of the padded frame axis.
            sequential: If True, features are returned as
                ``(feat_dim, num_cams)``; otherwise they are flattened to 1-D.
            num_frequencies, min_freq, max_freq: Stored on the instance; not
                read by this class's own methods.
            load_vertices: Also load mesh vertices and faces for each item.
            **kwargs: When ``standardize`` is True, must contain
                ``"standardization"``: a mapping with ``"norm_mean_h"``,
                ``"norm_std_h"`` (per-feature statistics) and ``"velocity"``
                (whether to velocity-encode frames after the first).

        Raises:
            KeyError: If ``standardize`` is True and the statistics are
                missing from ``kwargs``.
        """
        super().__init__()
        self.modality = "char"
        self.name = name
        self.dataset_dir = Path(dataset_dir)
        self.traj_dir = self.dataset_dir / "traj"
        self.data_dir = self.dataset_dir / self.name
        self.vert_dir = self.dataset_dir / "vert_raw"
        self.center_dir = self.dataset_dir / "char_raw"
        # Must be assigned (a sequence of file stems) before len()/indexing.
        self.filenames = None

        self.standardize = standardize
        # Velocity encoding is configured by the standardization statistics.
        # Default it so __getitem__ also works when standardize is False.
        self.velocity = False
        if self.standardize:
            mean_std = kwargs["standardization"]
            # Stored as (feat_dim, 1) column vectors.
            self.norm_mean = torch.Tensor(mean_std["norm_mean_h"])[:, None]
            self.norm_std = torch.Tensor(mean_std["norm_std_h"])[:, None]
            self.velocity = mean_std["velocity"]

        self.num_cams = num_cams
        self.num_feats = num_feats
        self.sequential = sequential
        self.num_frequencies = num_frequencies
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.load_vertices = load_vertices

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        """Load and preprocess sample ``index``.

        Returns:
            A tuple ``(char_filename, padded_char_feature, raw_feats)``:

            * ``char_filename``: ``filenames[index] + ".npy"``.
            * ``padded_char_feature``: the (optionally velocity-encoded and
              standardized) features padded to ``num_cams`` frames. It is
              ``(feat_dim, num_cams)`` if ``sequential``, otherwise flattened.
            * ``raw_feats``: ``"char_raw_feat"`` holds the untransformed
              features, padded, as ``(feat_dim, num_cams)``. When
              ``load_vertices`` is set it also holds ``"char_vertices"`` and
              ``"char_faces"``.
        """
        char_filename = self.filenames[index] + ".npy"
        raw_char_feature = torch.from_numpy(
            np.load(self.data_dir / char_filename)
        ).to(torch.float32)

        # Same pad amount for every per-frame tensor of this sample.
        padding_size = self.num_cams - raw_char_feature.shape[0]
        padded_raw_char_feature = F.pad(
            raw_char_feature, (0, 0, 0, padding_size)
        ).permute(1, 0)

        char_feature = self._transform_feature(raw_char_feature)
        padded_char_feature = F.pad(char_feature, (0, 0, 0, padding_size))
        if self.sequential:
            padded_char_feature = padded_char_feature.permute(1, 0)
        else:
            padded_char_feature = padded_char_feature.reshape(-1)

        raw_feats = {"char_raw_feat": padded_raw_char_feature}
        if self.load_vertices:
            padded_verts, faces = self._load_mesh(char_filename, padding_size)
            raw_feats["char_vertices"] = padded_verts
            raw_feats["char_faces"] = faces
        return char_filename, padded_char_feature, raw_feats

    def _transform_feature(self, raw_char_feature):
        """Return a velocity-encoded and/or standardized copy of the features.

        ``raw_char_feature`` itself is never modified.
        """
        if self.velocity:
            # Keep frame 0 as-is and replace each later frame by its
            # difference from the previous frame.
            char_feature = torch.cat(
                [
                    raw_char_feature[:1],
                    raw_char_feature[1:] - raw_char_feature[:-1],
                ]
            )
        else:
            char_feature = raw_char_feature.clone()

        if not self.standardize:
            return char_feature

        device = raw_char_feature.device
        mean = self.norm_mean[:, 0].to(device)
        std = self.norm_std[:, 0].to(device)
        if len(mean) == 6:
            # Normalize the first frame (origin) with the first 3 statistics
            # and the remaining frames (velocity) with the last 3.
            char_feature[0] -= mean[:3]
            char_feature[0] /= std[:3]
            char_feature[1:] -= mean[3:]
            char_feature[1:] /= std[3:]
        else:
            # Normalize all frames with the same statistics.
            char_feature -= mean
            char_feature /= std
        return char_feature

    def _load_mesh(self, char_filename, padding_size):
        """Return ``(padded_verts, faces)`` for one sample.

        Vertices are offset by the first row of the matching ``char_raw``
        file and padded along the frame axis by ``padding_size``. If the
        stored ``"vertices"`` entry is None, zero placeholders are returned.
        """
        # NOTE(review): allow_pickle can execute arbitrary code embedded in
        # the file; only use with trusted dataset files.
        raw_verts = np.load(self.vert_dir / char_filename, allow_pickle=True)[()]
        if raw_verts["vertices"] is None:
            padded_verts = torch.zeros(
                (self.num_cams, self._NUM_VERTICES, 3), dtype=torch.float32
            )
            faces = torch.zeros((self._NUM_FACES, 3), dtype=torch.int16)
            return padded_verts, faces

        # Centre used to offset the mesh; only needed when vertices exist.
        center_offset = torch.from_numpy(
            np.load(self.center_dir / char_filename)[0]
        ).to(torch.float32)
        verts = torch.from_numpy(raw_verts["vertices"]).to(torch.float32)
        verts -= center_offset
        padded_verts = F.pad(verts, (0, 0, 0, 0, 0, padding_size))
        faces = torch.from_numpy(raw_verts["faces"]).to(torch.int16)
        return padded_verts, faces