robin-courant's picture
Add app
f7a5cb1 verified
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
# ------------------------------------------------------------------------------------- #
# Module-level default. Not referenced by the code in this file:
# CharacterDataset takes its own `num_frequencies` constructor argument, which
# shadows this name. NOTE(review): confirm no other module imports this before
# removing it.
num_frequencies = None
# ------------------------------------------------------------------------------------- #
class CharacterDataset(Dataset):
    """Per-sequence character trajectory dataset.

    Each item is read from ``<dataset_dir>/<name>/<filename>.npy`` and padded
    along its first (frame) axis to ``num_cams`` frames. The feature can
    optionally be converted to "first frame + per-frame deltas" form and
    standardized. Mesh vertices/faces can optionally be loaded from
    ``<dataset_dir>/vert_raw``.

    ``self.filenames`` starts as ``None`` and must be set to a sequence of
    file stems (without ``.npy``) before the dataset is indexed.
    NOTE(review): presumably populated by a subclass or the caller — confirm.
    """

    # Placeholder mesh size used when a sequence has no stored vertices.
    # NOTE(review): these values match the SMPL body model (6890 vertices,
    # 13776 faces); confirm against the files in ``vert_raw``.
    _NUM_VERTICES = 6890
    _NUM_FACES = 13776

    def __init__(
        self,
        name: str,
        dataset_dir: str,
        standardize: bool,
        num_feats: int,
        num_cams: int,
        sequential: bool,
        num_frequencies: int,
        min_freq: int,
        max_freq: int,
        load_vertices: bool,
        **kwargs,
    ):
        """Configure paths and preprocessing options.

        Args:
            name: Sub-directory of ``dataset_dir`` holding the feature files.
            dataset_dir: Root directory of the dataset.
            standardize: Normalize features with the statistics in
                ``kwargs["standardization"]``.
            num_feats: Stored as ``self.num_feats``; not used in this class.
            num_cams: Number of frames every item is padded to.
            sequential: If True, features are returned with the frame axis
                last; otherwise they are flattened to 1-D.
            num_frequencies: Stored as an attribute; not used in this class.
            min_freq: Stored as an attribute; not used in this class.
            max_freq: Stored as an attribute; not used in this class.
            load_vertices: Also return mesh vertices and faces per item.
            **kwargs: When ``standardize`` is True, must contain
                ``"standardization"``: a mapping with ``"norm_mean_h"``,
                ``"norm_std_h"`` and ``"velocity"`` entries.
        """
        super().__init__()
        self.modality = "char"
        self.name = name
        self.dataset_dir = Path(dataset_dir)
        self.traj_dir = self.dataset_dir / "traj"
        self.data_dir = self.dataset_dir / self.name
        self.vert_dir = self.dataset_dir / "vert_raw"
        self.center_dir = self.dataset_dir / "char_raw"
        self.filenames = None

        self.standardize = standardize
        # Velocity mode is only configured through the standardization stats.
        # Default it so __getitem__ also works when standardize is False
        # (previously that combination raised AttributeError).
        self.velocity = False
        if self.standardize:
            mean_std = kwargs["standardization"]
            self.norm_mean = torch.Tensor(mean_std["norm_mean_h"])[:, None]
            self.norm_std = torch.Tensor(mean_std["norm_std_h"])[:, None]
            self.velocity = mean_std["velocity"]

        self.num_cams = num_cams
        self.num_feats = num_feats
        self.sequential = sequential
        self.num_frequencies = num_frequencies
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.load_vertices = load_vertices

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        """Load, transform and pad one sequence.

        Returns:
            ``(char_filename, padded_char_feature, raw_feats)``:

            - ``char_filename``: the ``.npy`` file name of the item.
            - ``padded_char_feature``: the optionally velocity-encoded and
              standardized feature, padded to ``num_cams`` frames. It is
              transposed to ``(feat, num_cams)`` when ``sequential`` is set,
              otherwise flattened.
            - ``raw_feats``: a dict with the padded, untransformed feature
              under ``"char_raw_feat"`` (``(feat, num_cams)``), plus
              ``"char_vertices"`` and ``"char_faces"`` when ``load_vertices``
              is set.
        """
        filename = self.filenames[index]
        char_filename = filename + ".npy"
        raw_char_feature = torch.from_numpy(
            np.load(self.data_dir / char_filename)
        ).to(torch.float32)

        # Zero-pad the end of the frame axis up to num_cams. F.pad treats a
        # negative amount as cropping, so longer sequences are truncated.
        # The frame count is unchanged by _transform, so this value is reused.
        padding_size = self.num_cams - raw_char_feature.shape[0]
        padded_raw_char_feature = F.pad(
            raw_char_feature, (0, 0, 0, padding_size)
        ).permute(1, 0)

        raw_feats = {"char_raw_feat": padded_raw_char_feature}
        if self.load_vertices:
            padded_verts, faces = self._load_vertices(char_filename, padding_size)
            raw_feats["char_vertices"] = padded_verts
            raw_feats["char_faces"] = faces

        char_feature = self._transform(raw_char_feature)
        padded_char_feature = F.pad(char_feature, (0, 0, 0, padding_size))
        if self.sequential:
            padded_char_feature = padded_char_feature.permute(1, 0)
        else:
            padded_char_feature = padded_char_feature.reshape(-1)

        return char_filename, padded_char_feature, raw_feats

    def _load_vertices(self, char_filename: str, padding_size: int):
        """Return ``(padded_verts, faces)`` for one sequence.

        Vertices have the first entry of the matching ``char_raw`` file
        subtracted and are padded along the frame axis like the features.
        A file storing ``vertices=None`` yields all-zero placeholders.
        """
        # NOTE(review): allow_pickle=True can execute arbitrary code from the
        # file; only use with trusted dataset files.
        raw_verts = np.load(self.vert_dir / char_filename, allow_pickle=True)[()]
        if raw_verts["vertices"] is None:
            padded_verts = torch.zeros(
                (self.num_cams, self._NUM_VERTICES, 3), dtype=torch.float32
            )
            faces = torch.zeros((self._NUM_FACES, 3), dtype=torch.int16)
            return padded_verts, faces

        # Offset used to re-center the mesh. Loaded only when it is needed, so
        # datasets without a char_raw folder still work without vertices.
        center_offset = torch.from_numpy(
            np.load(self.center_dir / char_filename)[0]
        ).to(torch.float32)
        verts = torch.from_numpy(raw_verts["vertices"]).to(torch.float32)
        verts = verts - center_offset
        padded_verts = F.pad(verts, (0, 0, 0, 0, 0, padding_size))
        faces = torch.from_numpy(raw_verts["faces"]).to(torch.int16)
        return padded_verts, faces

    def _transform(self, raw_char_feature):
        """Apply optional velocity encoding and standardization.

        ``raw_char_feature`` is never modified; a new tensor is returned.
        """
        char_feature = raw_char_feature.clone()
        if self.velocity:
            # Keep frame 0 as the origin; replace later frames by deltas.
            velocity = char_feature[1:] - char_feature[:-1]
            char_feature = torch.cat([char_feature[:1], velocity])
        if self.standardize:
            device = char_feature.device
            mean = self.norm_mean[:, 0].to(device)
            std = self.norm_std[:, 0].to(device)
            if len(self.norm_mean) == 6:
                # Separate stats: first 3 for the origin frame, last 3 for the
                # remaining (velocity) frames.
                char_feature[0] -= mean[:3]
                char_feature[0] /= std[:3]
                char_feature[1:] -= mean[3:]
                char_feature[1:] /= std[3:]
            else:
                # One set of stats for all frames.
                char_feature -= mean
                char_feature /= std
        return char_feature