from os import path as osp
from typing import Dict
from unicodedata import name

import numpy as np
import torch
import torch.utils as utils
from numpy.linalg import inv

from src.utils.dataset import (
    read_scannet_gray,
    read_scannet_depth,
    read_scannet_pose,
    read_scannet_intrinsic
)


class ScanNetDataset(utils.data.Dataset):
    def __init__(self,
                 root_dir,
                 npz_path,
                 intrinsic_path,
                 mode='train',
                 min_overlap_score=0.4,
                 augment_fn=None,
                 pose_dir=None,
                 img_resize=None,
                 fp16=False,
                 **kwargs):
        """Manage one scene of ScanNet Dataset.

        Args:
            root_dir (str): ScanNet root directory that contains scene folders.
            npz_path (str): {scene_id}.npz path. This contains image pair
                information of a scene.
            intrinsic_path (str): path to depth-camera intrinsic file.
            mode (str): options are ['train', 'val', 'test'].
            min_overlap_score (float): pairs whose 'score' entry in the npz is
                not strictly greater than this value are dropped. Only applied
                when the npz has a 'score' array and mode is neither 'val'
                nor 'test'.
            augment_fn (callable, optional): augments images with pre-defined
                visual effects. Only kept when mode == 'train'.
            pose_dir (str): ScanNet root directory that contains all poses.
                (we use a separate (optional) pose_dir since we store images
                and poses separately.) Falls back to root_dir when None.
            img_resize (sequence, optional): (width, height) forwarded to
                read_scannet_gray and used to compute the 'scale0'/'scale1'
                entries. NOTE(review): __getitem__ indexes this value, so
                None raises a TypeError there -- confirm callers always
                pass it.
            fp16 (bool): when True, images, depth maps and scales are cast to
                half precision in __getitem__.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.root_dir = root_dir
        self.pose_dir = pose_dir if pose_dir is not None else root_dir
        self.mode = mode

        # prepare data_names, intrinsics and extrinsics(T)
        with np.load(npz_path) as data:
            self.data_names = data['name']
            # The original condition was `mode not in ['val' or 'test']`, which
            # evaluates to `mode not in ['val']` and therefore also filtered
            # the test split. Evaluation splits must keep every pair.
            if 'score' in data.files and mode not in ('val', 'test'):
                kept_mask = data['score'] > min_overlap_score
                self.data_names = self.data_names[kept_mask]

        # Materialise every array while the archive is open so the underlying
        # file handle is released instead of leaking for the dataset lifetime.
        with np.load(intrinsic_path) as intrinsics:
            self.intrinsics = dict(intrinsics)

        # for training LoFTR
        self.augment_fn = augment_fn if mode == 'train' else None

        self.fp16 = fp16
        self.img_resize = img_resize

    def __len__(self):
        """Return the number of image pairs kept for this scene."""
        return len(self.data_names)

    def _read_abs_pose(self, scene_name, name):
        """Read the pose stored at {pose_dir}/{scene_name}/pose/{name}.txt."""
        pth = osp.join(self.pose_dir,
                       scene_name,
                       'pose', f'{name}.txt')
        return read_scannet_pose(pth)

    def _compute_rel_pose(self, scene_name, name0, name1):
        """Return pose1 @ inv(pose0) for the two named frames as a (4, 4) array."""
        pose0 = self._read_abs_pose(scene_name, name0)
        pose1 = self._read_abs_pose(scene_name, name1)

        return np.matmul(pose1, inv(pose0))  # (4, 4)

    def __getitem__(self, idx):
        """Load one image pair with its depth, intrinsics and relative pose.

        Args:
            idx (int): index into the kept pair list.

        Returns:
            dict: images, depth maps, relative poses in both directions,
            intrinsics, scales and identifying metadata for the pair.
        """
        data_name = self.data_names[idx]
        scene_name, scene_sub_name, stem_name_0, stem_name_1 = data_name
        scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'

        # read the grayscale images, resized according to self.img_resize
        img_name0 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_0}.jpg')
        img_name1 = osp.join(self.root_dir, scene_name, 'color', f'{stem_name_1}.jpg')

        # TODO: Support augmentation & handle seeds for each worker correctly.
        # Until then augmentation is disabled (augment_fn=None) even in train
        # mode; the intended call would randomly pick self.augment_fn or None.
        image0 = read_scannet_gray(img_name0, resize=self.img_resize, augment_fn=None)
        image1 = read_scannet_gray(img_name1, resize=self.img_resize, augment_fn=None)

        # read the depthmap which is stored as (480, 640)
        if self.mode in ['train', 'val']:
            depth0 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_0}.png'))
            depth1 = read_scannet_depth(osp.join(self.root_dir, scene_name, 'depth', f'{stem_name_1}.png'))
        else:
            # depth is not loaded for other modes; use empty placeholders
            depth0 = depth1 = torch.tensor([])

        # read the intrinsic of depthmap (both frames share the scene intrinsic)
        K_0 = K_1 = torch.tensor(self.intrinsics[scene_name].copy(), dtype=torch.float).reshape(3, 3)

        # read and compute relative poses
        T_0to1 = torch.tensor(self._compute_rel_pose(scene_name, stem_name_0, stem_name_1),
                              dtype=torch.float32)
        T_1to0 = T_0to1.inverse()

        # img_resize is (w, h); scale is the ratio of 640x480 to the resized image
        h_new, w_new = self.img_resize[1], self.img_resize[0]
        scale0 = torch.tensor([640/w_new, 480/h_new], dtype=torch.float)
        scale1 = torch.tensor([640/w_new, 480/h_new], dtype=torch.float)

        if self.fp16:
            # poses and intrinsics intentionally stay in float32
            image0, image1, depth0, depth1, scale0, scale1 = (
                t.half() for t in (image0, image1, depth0, depth1, scale0, scale1)
            )

        data = {
            'image0': image0,   # (1, h, w)
            'depth0': depth0,   # (h, w)
            'image1': image1,
            'depth1': depth1,
            'T_0to1': T_0to1,   # (4, 4)
            'T_1to0': T_1to0,
            'K0': K_0,  # (3, 3)
            'K1': K_1,
            'scale0': scale0,  # [scale_w, scale_h]
            'scale1': scale1,
            'dataset_name': 'ScanNet',
            'scene_id': scene_name,
            'pair_id': idx,
            'pair_names': (osp.join(scene_name, 'color', f'{stem_name_0}.jpg'),
                           osp.join(scene_name, 'color', f'{stem_name_1}.jpg'))
        }

        return data