import json
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
class SGDataset(Dataset):
def __init__(
self,
path_data_definition,
path_processed_data,
window,
style_encoding_type,
example_window_length,
):
"""PyTorch Dataset Instance
Args:
path_data_definition : Path to data_definition file
path_processed_data : Path to processed_data npz file
window : Length of the input-output slice
style_encoding_type : "label" or "example"
example_window_length : Length of example window
"""
with open(path_data_definition, "r") as f:
details = json.load(f)
self.details = details
self.njoints = len(details["bone_names"])
self.nlabels = len(details["label_names"])
self.label_names = details["label_names"]
self.bone_names = details["bone_names"]
self.parents = torch.LongTensor(details["parents"])
self.dt = details["dt"]
self.window = window
self.style_encoding_type = style_encoding_type
self.example_window_length = example_window_length
# Load Data
processed_data = np.load(path_processed_data)
self.ranges_train = processed_data["ranges_train"]
self.ranges_valid = processed_data["ranges_valid"]
self.ranges_train_labels = processed_data["ranges_train_labels"]
self.ranges_valid_labels = processed_data["ranges_valid_labels"]
self.X_audio_features = torch.as_tensor(
processed_data["X_audio_features"], dtype=torch.float32
)
self.Y_root_pos = torch.as_tensor(processed_data["Y_root_pos"], dtype=torch.float32)
self.Y_root_rot = torch.as_tensor(processed_data["Y_root_rot"], dtype=torch.float32)
self.Y_root_vel = torch.as_tensor(processed_data["Y_root_vel"], dtype=torch.float32)
self.Y_root_vrt = torch.as_tensor(processed_data["Y_root_vrt"], dtype=torch.float32)
self.Y_lpos = torch.as_tensor(processed_data["Y_lpos"], dtype=torch.float32)
self.Y_ltxy = torch.as_tensor(processed_data["Y_ltxy"], dtype=torch.float32)
self.Y_lvel = torch.as_tensor(processed_data["Y_lvel"], dtype=torch.float32)
self.Y_lvrt = torch.as_tensor(processed_data["Y_lvrt"], dtype=torch.float32)
self.Y_gaze_pos = torch.as_tensor(processed_data["Y_gaze_pos"], dtype=torch.float32)
self.audio_input_mean = torch.as_tensor(
processed_data["audio_input_mean"], dtype=torch.float32
)
self.audio_input_std = torch.as_tensor(
processed_data["audio_input_std"], dtype=torch.float32
)
self.anim_input_mean = torch.as_tensor(
processed_data["anim_input_mean"], dtype=torch.float32
)
        self.anim_input_std = torch.as_tensor(
            processed_data["anim_input_std"], dtype=torch.float32
        )
self.anim_output_mean = torch.as_tensor(
processed_data["anim_output_mean"], dtype=torch.float32
)
self.anim_output_std = torch.as_tensor(
processed_data["anim_output_std"], dtype=torch.float32
)
# Build Windows
R = []
L = []
S = []
for sample_number, ((range_start, range_end), range_label) in enumerate(
zip(self.ranges_train, self.ranges_train_labels)
):
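            # Slide a stride-1 window over each training clip: R stores the frame
            # indices of every window, L its clip-level one-hot style label, and
            # S the index of the source clip (used later to locate its range).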
one_hot_label = np.zeros(self.nlabels, dtype=np.float32)
one_hot_label[range_label] = 1.0
for ri in range(range_start, range_end - window):
R.append(np.arange(ri, ri + window))
L.append(one_hot_label)
S.append(sample_number)
self.R = torch.as_tensor(np.array(R), dtype=torch.long)
self.L = torch.as_tensor(np.array(L), dtype=torch.float32)
        self.S = torch.as_tensor(S, dtype=torch.long)  # long: used as an index; avoids int16 overflow on large datasets
# self.get_stats()
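
    # Property access so example_window_length can be adjusted after construction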
@property
def example_window_length(self):
return self._example_window_length
@example_window_length.setter
def example_window_length(self, a):
self._example_window_length = a
def __len__(self):
return len(self.R)
def __getitem__(self, index):
# Extract Windows
        Rwindow = self.R[index].contiguous()
# Extract Labels
Rlabel = self.L[index]
# Get Corresponding Ranges for Style Encoding
RInd = self.S[index]
sample_range = self.ranges_train[RInd]
# Extract Audio
W_audio_features = self.X_audio_features[Rwindow]
# Extract Animation
W_root_pos = self.Y_root_pos[Rwindow]
W_root_rot = self.Y_root_rot[Rwindow]
W_root_vel = self.Y_root_vel[Rwindow]
W_root_vrt = self.Y_root_vrt[Rwindow]
W_lpos = self.Y_lpos[Rwindow]
W_ltxy = self.Y_ltxy[Rwindow]
W_lvel = self.Y_lvel[Rwindow]
W_lvrt = self.Y_lvrt[Rwindow]
W_gaze_pos = self.Y_gaze_pos[Rwindow]
        if self.style_encoding_type == "label":
            style = Rlabel
        elif self.style_encoding_type == "example":
            style = self.get_example(Rwindow, sample_range, self.example_window_length)
        else:
            raise ValueError(f"Unknown style encoding type: {self.style_encoding_type!r}")
return (
W_audio_features,
W_root_pos,
W_root_rot,
W_root_vel,
W_root_vrt,
W_lpos,
W_ltxy,
W_lvel,
W_lvrt,
W_gaze_pos,
style,
)
def get_shapes(self):
num_audio_features = self.X_audio_features.shape[1]
pose_input_size = len(self.anim_input_std)
pose_output_size = len(self.anim_output_std)
dimensions = dict(
num_audio_features=num_audio_features,
pose_input_size=pose_input_size,
pose_output_size=pose_output_size,
)
return dimensions
def get_means_stds(self, device):
return (
self.audio_input_mean.to(device),
self.audio_input_std.to(device),
self.anim_input_mean.to(device),
self.anim_input_std.to(device),
self.anim_output_mean.to(device),
self.anim_output_std.to(device),
)
    def get_example(self, Rwindow, sample_range, example_window_length):
        # Grow the current window symmetrically by ext_window frames per side so
        # the style example covers example_window_length frames in total.
        ext_window = (example_window_length - self.window) // 2
        # Room actually available on each side within this clip
        ws = min(ext_window, Rwindow[0] - sample_range[0])
        we = min(ext_window, sample_range[1] - Rwindow[-1])
        # If one side runs out of room, shift the surplus to the other side
        s_ext = ws + ext_window - we
        w_ext = we + ext_window - ws
        start = max(Rwindow[0] - s_ext, sample_range[0])
        end = min(Rwindow[-1] + w_ext, sample_range[1]) + 1
        end = min(end, len(self.Y_root_vel))
S_root_vel = self.Y_root_vel[start:end].reshape(end - start, -1)
S_root_vrt = self.Y_root_vrt[start:end].reshape(end - start, -1)
S_lpos = self.Y_lpos[start:end].reshape(end - start, -1)
S_ltxy = self.Y_ltxy[start:end].reshape(end - start, -1)
S_lvel = self.Y_lvel[start:end].reshape(end - start, -1)
S_lvrt = self.Y_lvrt[start:end].reshape(end - start, -1)
        example_feature_vec = torch.cat(
            [
                S_root_vel,
                S_root_vrt,
                S_lpos,
                S_ltxy,
                S_lvel,
                S_lvrt,
                torch.zeros_like(S_root_vel),  # zero block with the same width as S_root_vel
            ],
            dim=1,
        )
        curr_len = len(example_feature_vec)
        if curr_len < example_window_length:
            # Pad to the target length by repeating the last frames
            example_feature_vec = torch.cat(
                [example_feature_vec, example_feature_vec[-(example_window_length - curr_len):]],
                dim=0,
            )
return example_feature_vec
    def get_sample(self, dataset, length=None, range_index=None):
        if dataset == "train":
            if range_index is None:
                range_index = np.random.randint(len(self.ranges_train))
            (s, e), label = self.ranges_train[range_index], self.ranges_train_labels[range_index]
        elif dataset == "valid":
            if range_index is None:
                range_index = np.random.randint(len(self.ranges_valid))
            (s, e), label = self.ranges_valid[range_index], self.ranges_valid_labels[range_index]
        else:
            raise ValueError(f"Unknown dataset split: {dataset!r} (expected 'train' or 'valid')")
        if length is not None:
            # `length` is given in seconds; the data is assumed to run at 60 fps,
            # matching the frames-to-seconds conversion used in get_stats.
            e = min(s + length * 60, e)
return (
self.X_audio_features[s:e][np.newaxis],
self.Y_root_pos[s:e][np.newaxis],
self.Y_root_rot[s:e][np.newaxis],
self.Y_root_vel[s:e][np.newaxis],
self.Y_root_vrt[s:e][np.newaxis],
self.Y_lpos[s:e][np.newaxis],
self.Y_ltxy[s:e][np.newaxis],
self.Y_lvel[s:e][np.newaxis],
self.Y_lvrt[s:e][np.newaxis],
self.Y_gaze_pos[s:e][np.newaxis],
label,
[s, e],
range_index,
)
def get_stats(self):
from rich.console import Console
from rich.table import Table
console = Console(record=True)
        # Per-style data statistics
df = pd.DataFrame()
df["Dataset"] = ["Train", "Validation", "Total"]
pd.set_option("display.max_rows", None, "display.max_columns", None)
table = Table(title="Data Info", show_lines=True, row_styles=["magenta"])
table.add_column("Dataset")
data_len = 0
for i in range(self.nlabels):
ind_mask = self.ranges_train_labels == i
ranges = self.ranges_train[ind_mask]
            # Integer-divide by two because the data also contains mirrored versions
            num_train_frames = np.sum(ranges[:, 1] - ranges[:, 0]) // 2
            ind_mask = self.ranges_valid_labels == i
            ranges = self.ranges_valid[ind_mask]
            num_valid_frames = np.sum(ranges[:, 1] - ranges[:, 0]) // 2
total = num_train_frames + num_valid_frames
df[self.label_names[i]] = [
f"{num_train_frames} frames - {num_train_frames / 60:.1f} secs",
f"{num_valid_frames} frames - {num_valid_frames / 60:.1f} secs",
f"{total} frames - {total / 60:.1f} secs",
]
table.add_column(self.label_names[i])
data_len += total
for i in range(3):
table.add_row(*list(df.iloc[i]))
console.print(table)
dimensions = self.get_shapes()
console.print(f"Total length of dataset is {data_len} frames - {data_len / 60:.1f} seconds")
console.print("Num features: ", dimensions)