Spaces:
Sleeping
Sleeping
from pathlib import Path | |
from evo.tools.file_interface import read_kitti_poses_file | |
import numpy as np | |
import torch | |
from torch.utils.data import Dataset | |
from torchtyping import TensorType | |
import torch.nn.functional as F | |
from typing import Tuple | |
from utils.file_utils import load_txt | |
from utils.rotation_utils import compute_rotation_matrix_from_ortho6d | |
# Module-level binding for the "num_cams" dimension label that appears inside
# the TensorType[...] annotations of TrajectoryDataset.
# NOTE(review): no visible code reads this value. It presumably exists only so
# that linters which parse string annotations do not report "num_cams" as an
# undefined name — confirm before removing.
num_cams = None

# ------------------------------------------------------------------------------------- #
class TrajectoryDataset(Dataset):
    """Camera-trajectory dataset yielding 6D-rotation + translation features.

    Each sample is read from a KITTI-format pose file, where each pose is a
    4x4 matrix. The poses are converted by :meth:`get_feature` into a
    ``(9, num_cams)`` feature tensor. That tensor is zero-padded (or
    truncated) along the camera axis to ``num_cams``. It is returned together
    with a padding mask and the per-sample intrinsics array.

    Args:
        name: Dataset variant. ``"relative"`` reads poses from
            ``<dataset_dir>/traj_raw`` and records ``<dataset_dir>/relative``
            as ``self.relative_dir``. Any other value reads poses from
            ``<dataset_dir>/traj``.
        set_name: Stored as ``self.set_name``; not used by this class.
        dataset_dir: Root directory holding the pose files, the
            ``intrinsics`` directory and the ``<split>_split.txt`` files.
        num_rawfeats: Stored as ``self.num_rawfeats``; not used by this class.
        num_feats: Stored as ``self.num_feats``; not used by this class.
        num_cams: Length of the camera axis that every sample is padded or
            truncated to.
        standardize: If True, translations are standardized with the
            statistics in ``kwargs["standardization"]``.
        **kwargs: When ``standardize`` is True, must contain
            ``"standardization"``, a mapping with ``"norm_mean"``,
            ``"norm_std"``, ``"shift_mean"``, ``"shift_std"`` and
            ``"velocity"`` entries.
    """

    def __init__(
        self,
        name: str,
        set_name: str,
        dataset_dir: str,
        num_rawfeats: int,
        num_feats: int,
        num_cams: int,
        standardize: bool,
        **kwargs,
    ):
        super().__init__()
        self.name = name
        self.set_name = set_name
        self.dataset_dir = Path(dataset_dir)
        if name == "relative":
            self.data_dir = self.dataset_dir / "traj_raw"
            self.relative_dir = self.dataset_dir / "relative"
        else:
            self.data_dir = self.dataset_dir / "traj"
        self.intrinsics_dir = self.dataset_dir / "intrinsics"

        self.num_rawfeats = num_rawfeats
        self.num_feats = num_feats
        self.num_cams = num_cams
        self.augmentation = None

        self.standardize = standardize
        # get_feature/get_matrix always read `velocity`. It used to be set only
        # when standardizing, so standardize=False crashed with AttributeError.
        # Default to absolute translations in that case.
        self.velocity = False
        if self.standardize:
            mean_std = kwargs["standardization"]
            # Stats for every translation after the first.
            self.norm_mean = torch.Tensor(mean_std["norm_mean"])
            self.norm_std = torch.Tensor(mean_std["norm_std"])
            # Stats for the first translation only.
            self.shift_mean = torch.Tensor(mean_std["shift_mean"])
            self.shift_std = torch.Tensor(mean_std["shift_std"])
            self.velocity = mean_std["velocity"]

    # --------------------------------------------------------------------------------- #

    def set_split(self, split: str, train_rate: float = 1.0):
        """Load the sorted sample names listed in ``<dataset_dir>/<split>_split.txt``.

        Args:
            split: Split name, used as the prefix of the split file.
            train_rate: Accepted for interface compatibility; currently unused.

        Returns:
            ``self``, so the call can be chained.
        """
        self.split = split
        split_path = self.dataset_dir / f"{split}_split.txt"
        # Skip blank lines (e.g. a trailing newline) and strip "\r"/spaces so
        # they don't turn into bogus sample names such as ".txt".
        split_traj = [
            line.strip()
            for line in load_txt(split_path).splitlines()
            if line.strip()
        ]
        self.filenames = sorted(split_traj)
        return self

    # --------------------------------------------------------------------------------- #

    def get_feature(
        self, raw_matrix_trajectory: TensorType["num_cams", 4, 4]
    ) -> TensorType[9, "num_cams"]:
        """Convert a stack of 4x4 camera poses into a ``(9, num_cams)`` feature.

        Rows 0-5 hold the 6D rotation: the first two columns of each rotation
        matrix, concatenated column after column. Rows 6-8 hold the
        translation.

        If ``self.velocity`` is set, every translation after the first is
        replaced by its difference from the previous one. If
        ``self.standardize`` is set, the first translation is standardized
        with the shift statistics and the remaining ones with the norm
        statistics. The input tensor is not modified.
        """
        matrix_trajectory = torch.clone(raw_matrix_trajectory)
        raw_trans = torch.clone(matrix_trajectory[:, :3, 3])
        if self.velocity:
            velocity = raw_trans[1:] - raw_trans[:-1]
            raw_trans = torch.cat([raw_trans[0][None], velocity])
        if self.standardize:
            # Match get_matrix: move the stats to the input's device.
            device = raw_trans.device
            raw_trans[0] -= self.shift_mean.to(device)
            raw_trans[0] /= self.shift_std.to(device)
            raw_trans[1:] -= self.norm_mean.to(device)
            raw_trans[1:] /= self.norm_std.to(device)

        # 6D continuous rotation representation:
        # (N, 3, 2) -> (N, 2, 3) -> (N, 6).
        raw_rot = matrix_trajectory[:, :3, :3]
        rot6d = raw_rot[:, :, :2].permute(0, 2, 1).reshape(-1, 6)

        # Stack rotation and translation as (N, 9), then transpose to (9, N).
        rot6d_trajectory = torch.hstack([rot6d, raw_trans]).permute(1, 0)
        return rot6d_trajectory

    def get_matrix(
        self, raw_rot6d_trajectory: TensorType[9, "num_cams"]
    ) -> TensorType["num_cams", 4, 4]:
        """Invert :meth:`get_feature`: rebuild 4x4 poses from a ``(9, N)`` feature.

        Translation rows are de-standardized if ``self.standardize`` is set,
        then cumulatively summed if ``self.velocity`` is set. The 6D rotation
        rows are mapped back to 3x3 matrices by
        ``compute_rotation_matrix_from_ortho6d``. The input tensor is not
        modified.
        """
        rot6d_trajectory = torch.clone(raw_rot6d_trajectory)
        device = rot6d_trajectory.device
        n_cams = rot6d_trajectory.shape[1]
        matrix_trajectory = torch.eye(4, device=device)[None].repeat(n_cams, 1, 1)

        # A view into the clone, so the in-place ops below never touch the input.
        raw_trans = rot6d_trajectory[6:].permute(1, 0)
        if self.standardize:
            raw_trans[0] *= self.shift_std.to(device)
            raw_trans[0] += self.shift_mean.to(device)
            raw_trans[1:] *= self.norm_std.to(device)
            raw_trans[1:] += self.norm_mean.to(device)
        if self.velocity:
            raw_trans = torch.cumsum(raw_trans, dim=0)
        matrix_trajectory[:, :3, 3] = raw_trans

        rot6d = rot6d_trajectory[:6].permute(1, 0)
        matrix_trajectory[:, :3, :3] = compute_rotation_matrix_from_ortho6d(rot6d)
        return matrix_trajectory

    # --------------------------------------------------------------------------------- #

    def __getitem__(
        self, index: int
    ) -> Tuple[
        str, TensorType[9, "num_cams"], TensorType["num_cams"], np.ndarray
    ]:
        """Load one sample.

        Returns:
            A tuple of four items:

            - ``"<name>.txt"``, the pose file name.
            - The ``(9, num_cams)`` feature tensor, zero-padded (or truncated)
              along the camera axis.
            - A ``(num_cams,)`` mask: 1 for real cameras, 0 for padding.
            - The array loaded from ``<dataset_dir>/intrinsics/<name>.npy``.
        """
        filename = self.filenames[index]
        trajectory_filename = filename + ".txt"
        trajectory_path = self.data_dir / trajectory_filename

        trajectory = read_kitti_poses_file(trajectory_path)
        matrix_trajectory = torch.from_numpy(np.array(trajectory.poses_se3)).to(
            torch.float32
        )
        trajectory_feature = self.get_feature(matrix_trajectory)

        # Trajectories longer than num_cams are truncated. This is the same
        # result the original negative F.pad amount produced, made explicit.
        num_valid = min(trajectory_feature.shape[1], self.num_cams)
        padded_trajectory_feature = F.pad(
            trajectory_feature[:, :num_valid], (0, self.num_cams - num_valid)
        )
        # Padding mask: 1 for valid cams, 0 for padded cams.
        padding_mask = torch.ones(self.num_cams)
        padding_mask[num_valid:] = 0

        intrinsics_path = self.intrinsics_dir / (filename + ".npy")
        intrinsics = np.load(intrinsics_path)

        return (
            trajectory_filename,
            padded_trajectory_feature,
            padding_mask,
            intrinsics,
        )

    def __len__(self):
        """Number of samples in the split most recently loaded by :meth:`set_split`."""
        return len(self.filenames)