from pathlib import Path

from evo.tools.file_interface import read_kitti_poses_file
import numpy as np
import torch
from torch.utils.data import Dataset
from torchtyping import TensorType
import torch.nn.functional as F
from typing import Tuple

from utils.file_utils import load_txt
from utils.rotation_utils import compute_rotation_matrix_from_ortho6d

# Module-level placeholder for the "num_cams" name used in the string
# dimensions of the TensorType annotations below.
num_cams = None

# ------------------------------------------------------------------------------------- #


class TrajectoryDataset(Dataset):
    """Dataset of camera trajectories stored as KITTI-format pose files.

    Each item is converted from a stack of 4x4 pose matrices to a
    ``(9, num_cams)`` feature: the first six rows are the first two columns
    of every rotation matrix (6D rotation), the last three rows are the
    translation (optionally expressed as per-step velocities and/or
    standardized).
    """

    def __init__(
        self,
        name: str,
        set_name: str,
        dataset_dir: str,
        num_rawfeats: int,
        num_feats: int,
        num_cams: int,
        standardize: bool,
        **kwargs,
    ):
        """Store the configuration and resolve the data directories.

        Args:
            name: Dataset variant; ``"relative"`` reads trajectories from
                ``traj_raw`` (and also sets ``relative_dir``), anything else
                reads from ``traj``.
            set_name: Stored as-is on the instance.
            dataset_dir: Root directory of the dataset.
            num_rawfeats: Stored as-is on the instance.
            num_feats: Stored as-is on the instance.
            num_cams: Length every trajectory is padded/truncated to in
                ``__getitem__``.
            standardize: Whether translations are standardized.
            **kwargs: ``standardization`` must be a mapping with the keys
                ``norm_mean``, ``norm_std``, ``shift_mean``, ``shift_std``
                and ``velocity`` when ``standardize`` is true.

        Raises:
            KeyError: If ``standardize`` is true and ``standardization`` or
                one of its required keys is missing.
        """
        super().__init__()
        self.name = name
        self.set_name = set_name
        self.dataset_dir = Path(dataset_dir)
        if name == "relative":
            self.data_dir = self.dataset_dir / "traj_raw"
            self.relative_dir = self.dataset_dir / "relative"
        else:
            self.data_dir = self.dataset_dir / "traj"
        self.intrinsics_dir = self.dataset_dir / "intrinsics"

        self.num_rawfeats = num_rawfeats
        self.num_feats = num_feats
        self.num_cams = num_cams

        self.augmentation = None
        self.standardize = standardize

        # `velocity` is read by get_feature/get_matrix regardless of
        # `standardize`, so it must always exist on the instance.
        self.velocity = False
        standardization = kwargs.get("standardization")
        if standardization is not None and "velocity" in standardization:
            self.velocity = standardization["velocity"]

        if self.standardize:
            mean_std = kwargs["standardization"]
            self.norm_mean = torch.Tensor(mean_std["norm_mean"])
            self.norm_std = torch.Tensor(mean_std["norm_std"])
            self.shift_mean = torch.Tensor(mean_std["shift_mean"])
            self.shift_std = torch.Tensor(mean_std["shift_std"])
            self.velocity = mean_std["velocity"]

    # --------------------------------------------------------------------------------- #

    def set_split(self, split: str, train_rate: float = 1.0):
        """Load the sorted list of sample names for ``split``.

        Names are read from ``<dataset_dir>/<split>_split.txt``, one per
        line. Blank lines (e.g. from a trailing newline) are ignored so they
        do not become empty filenames.

        Args:
            split: Split name used to build the split file name.
            train_rate: Currently unused; kept for interface compatibility.

        Returns:
            ``self``, to allow call chaining.
        """
        self.split = split
        split_path = Path(self.dataset_dir) / f"{split}_split.txt"
        split_traj = load_txt(split_path).split("\n")
        self.filenames = sorted(
            entry.strip() for entry in split_traj if entry.strip()
        )

        return self

    # --------------------------------------------------------------------------------- #

    def get_feature(
        self, raw_matrix_trajectory: TensorType["num_cams", 4, 4]
    ) -> TensorType[9, "num_cams"]:
        """Convert a stack of 4x4 pose matrices to a ``(9, num_cams)`` feature.

        The input tensor is not modified.
        """
        # Clone the translation because it is edited in place below.
        raw_trans = torch.clone(raw_matrix_trajectory[:, :3, 3])
        if self.velocity:
            # Keep the first position, replace the rest by frame-to-frame deltas.
            velocity = raw_trans[1:] - raw_trans[:-1]
            raw_trans = torch.cat([raw_trans[0][None], velocity])
        if self.standardize:
            # Move the statistics to the input device, as get_matrix does.
            device = raw_trans.device
            raw_trans[0] -= self.shift_mean.to(device)
            raw_trans[0] /= self.shift_std.to(device)
            raw_trans[1:] -= self.norm_mean.to(device)
            raw_trans[1:] /= self.norm_std.to(device)

        # Compute the 6D continuous rotation
        raw_rot = raw_matrix_trajectory[:, :3, :3]
        rot6d = raw_rot[:, :, :2].permute(0, 2, 1).reshape(-1, 6)

        # Stack rotation 6D and translation
        rot6d_trajectory = torch.hstack([rot6d, raw_trans]).permute(1, 0)

        return rot6d_trajectory

    def get_matrix(
        self, raw_rot6d_trajectory: TensorType[9, "num_cams"]
    ) -> TensorType["num_cams", 4, 4]:
        """Convert a ``(9, num_cams)`` feature back to 4x4 pose matrices.

        Inverse of ``get_feature``: undoes the standardization, integrates
        velocities back to positions, and rebuilds rotation matrices from the
        6D representation. Every column of the input is treated as a camera.
        The input tensor is not modified.
        """
        rot6d_trajectory = torch.clone(raw_rot6d_trajectory)
        device = rot6d_trajectory.device

        num_cams = rot6d_trajectory.shape[1]
        matrix_trajectory = torch.eye(4, device=device)[None].repeat(num_cams, 1, 1)

        # View into the cloned feature; the in-place edits below do not
        # touch the caller's tensor.
        raw_trans = rot6d_trajectory[6:].permute(1, 0)
        if self.standardize:
            raw_trans[0] *= self.shift_std.to(device)
            raw_trans[0] += self.shift_mean.to(device)
            raw_trans[1:] *= self.norm_std.to(device)
            raw_trans[1:] += self.norm_mean.to(device)
        if self.velocity:
            # First row is a position, the rest are deltas: accumulate them.
            raw_trans = torch.cumsum(raw_trans, dim=0)
        matrix_trajectory[:, :3, 3] = raw_trans

        rot6d = rot6d_trajectory[:6].permute(1, 0)
        raw_rot = compute_rotation_matrix_from_ortho6d(rot6d)
        matrix_trajectory[:, :3, :3] = raw_rot

        return matrix_trajectory

    # --------------------------------------------------------------------------------- #

    def __getitem__(
        self, index: int
    ) -> Tuple[str, TensorType[9, "num_cams"], TensorType["num_cams"], np.ndarray]:
        """Load one trajectory with its padding mask and intrinsics.

        Returns:
            A tuple ``(trajectory_filename, padded_trajectory_feature,
            padding_mask, intrinsics)`` where the feature is padded with
            zeros (or truncated) to ``self.num_cams`` columns, the mask is 1
            for valid cams and 0 for padded ones, and ``intrinsics`` is the
            array loaded from the matching ``.npy`` file.
        """
        filename = self.filenames[index]

        trajectory_filename = filename + ".txt"
        trajectory_path = self.data_dir / trajectory_filename

        trajectory = read_kitti_poses_file(trajectory_path)
        matrix_trajectory = torch.from_numpy(np.array(trajectory.poses_se3)).to(
            torch.float32
        )

        trajectory_feature = self.get_feature(matrix_trajectory)

        # Trajectories longer than num_cams are truncated, shorter ones are
        # zero-padded on the right.
        num_valid = min(trajectory_feature.shape[1], self.num_cams)
        padded_trajectory_feature = F.pad(
            trajectory_feature[:, :num_valid], (0, self.num_cams - num_valid)
        )
        # Padding mask: 1 for valid cams, 0 for padded cams
        padding_mask = torch.ones((self.num_cams))
        padding_mask[num_valid:] = 0

        intrinsics_filename = filename + ".npy"
        intrinsics_path = self.intrinsics_dir / intrinsics_filename
        intrinsics = np.load(intrinsics_path)

        return (
            trajectory_filename,
            padded_trajectory_feature,
            padding_mask,
            intrinsics
        )

    def __len__(self):
        """Return the number of samples in the current split."""
        return len(self.filenames)