|
import os |
|
import random |
|
from PIL import Image |
|
import cv2 |
|
import h5py |
|
import numpy as np |
|
import torch |
|
from torch.utils.data import Dataset, DataLoader, ConcatDataset |
|
|
|
import torchvision.transforms.functional as tvf |
|
import kornia.augmentation as K |
|
import os.path as osp |
|
import matplotlib.pyplot as plt |
|
import roma |
|
from roma.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops |
|
from roma.utils.transforms import GeometricSequential |
|
from tqdm import tqdm |
|
|
|
|
|
class ScanNetScene:
    """Map-style dataset of image pairs taken from ScanNet scans.

    Each item is a pair of colour frames with their depth maps, the
    intrinsics rescaled to the target resolution, and the relative pose
    from the first camera to the second.

    Attributes:
        scene_root: ``<data_root>/scans/scans_train``.
        data_names: array with one row per pair; ``__getitem__`` unpacks a
            row as ``(scene, sub_scene, frame_1, frame_2)``.
        overlaps: per-pair values read from ``scene_info["score"]``.
        wt, ht: target width and height used for resizing.
        shake_t, H_generator: stored at construction time but not used by
            any method of this class.
        use_horizontal_flip_aug: enables the random horizontal flip.
    """

    def __init__(
        self,
        data_root,
        scene_info,
        ht=384,
        wt=512,
        min_overlap=0.0,
        shake_t=0,
        rot_prob=0.0,
        use_horizontal_flip_aug=False,
    ) -> None:
        """Select the pairs to serve and build the resize/normalise transforms.

        Args:
            data_root: dataset root; scans are read from
                ``<data_root>/scans/scans_train``.
            scene_info: mapping-like object providing ``"name"`` and
                ``"score"`` arrays.
            ht: target image height.
            wt: target image width.
            min_overlap: accepted for interface compatibility; not used here.
            shake_t: stored on the instance; not used by this class.
            rot_prob: probability passed to ``K.RandomAffine`` (``K`` is the
                kornia augmentation module at file level).
            use_horizontal_flip_aug: if True, ``__getitem__`` flips a pair
                horizontally when ``np.random.rand() > 0.5``.
        """
        self.scene_root = osp.join(data_root, "scans", "scans_train")
        self.data_names = scene_info["name"]
        self.overlaps = scene_info["score"]

        # Keep only pairs whose last two columns (the frame stems) are both
        # multiples of 10: the remainders sum to zero only if each is zero.
        valid = (self.data_names[:, -2:] % 10).sum(axis=-1) == 0
        self.overlaps = self.overlaps[valid]
        self.data_names = self.data_names[valid]

        # Cap the scene at 10000 pairs, sampled without replacement.
        if len(self.data_names) > 10000:
            pairinds = np.random.choice(
                np.arange(0, len(self.data_names)), 10000, replace=False
            )
            self.data_names = self.data_names[pairinds]
            self.overlaps = self.overlaps[pairinds]

        self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
        self.depth_transform_ops = get_depth_tuple_transform_ops(
            resize=(ht, wt), normalize=False
        )
        self.wt, self.ht = wt, ht
        self.shake_t = shake_t
        self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
        self.use_horizontal_flip_aug = use_horizontal_flip_aug

    def load_im(self, im_B, crop=None):
        """Open the image at path ``im_B`` with PIL and return it.

        ``crop`` is accepted but ignored.
        """
        im = Image.open(im_B)
        return im

    def load_depth(self, depth_ref, crop=None):
        """Read a depth image and return it as a float tensor divided by 1000.

        The division presumably converts millimetres to metres -- confirm
        against the dataset's depth encoding. ``crop`` is accepted but ignored.

        Raises:
            OSError: if OpenCV cannot read the file.
        """
        depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
        if depth is None:
            # cv2.imread signals a missing/undecodable file by returning None
            # rather than raising; fail here with the offending path instead
            # of an opaque TypeError on the division below.
            raise OSError(f"Could not read depth map: {depth_ref}")
        depth = depth / 1000
        depth = torch.from_numpy(depth).float()
        return depth

    def __len__(self):
        """Return the number of pairs kept for this scene."""
        return len(self.data_names)

    def scale_intrinsic(self, K, wi, hi):
        """Rescale intrinsics ``K`` from a ``wi`` x ``hi`` image to ``wt`` x ``ht``.

        Args:
            K: 3x3 intrinsic matrix (torch tensor).
            wi: original image width.
            hi: original image height.

        Returns:
            ``diag(wt / wi, ht / hi, 1) @ K``.
        """
        sx, sy = self.wt / wi, self.ht / hi
        sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
        return sK @ K

    def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
        """Mirror both images and depth maps along their last axis.

        The intrinsics are left-multiplied by a matrix that maps the
        x-coordinate to ``wt - x`` so they stay consistent with the flip.

        Returns:
            The flipped ``(im_A, im_B, depth_A, depth_B, K_A, K_B)``.
        """
        im_A = im_A.flip(-1)
        im_B = im_B.flip(-1)
        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
        flip_mat = torch.tensor([[-1, 0, self.wt], [0, 1, 0], [0, 0, 1.0]]).to(
            K_A.device
        )
        K_A = flip_mat @ K_A
        K_B = flip_mat @ K_B

        return im_A, im_B, depth_A, depth_B, K_A, K_B

    def read_scannet_pose(self, path):
        """Read ScanNet's Camera2World pose and transform it to World2Camera.

        Returns:
            pose_w2c (np.ndarray): (4, 4)
        """
        cam2world = np.loadtxt(path, delimiter=" ")
        world2cam = np.linalg.inv(cam2world)
        return world2cam

    def read_scannet_intrinsic(self, path):
        """Read ScanNet's intrinsic matrix and return the 3x3 matrix."""
        intrinsic = np.loadtxt(path, delimiter=" ")
        # Drop the last row and column of the stored matrix.
        return torch.tensor(intrinsic[:-1, :-1], dtype=torch.float)

    def __getitem__(self, pair_idx):
        """Load one image pair with depth, intrinsics and relative pose.

        Args:
            pair_idx: index into ``self.data_names``.

        Returns:
            dict with keys ``im_A``, ``im_B`` (transformed images),
            ``im_A_depth``, ``im_B_depth`` (transformed depth maps with the
            two leading singleton dims removed), ``K1``, ``K2`` (3x3
            intrinsics rescaled to the target size) and ``T_1to2`` (4x4
            float tensor, ``T2_w2c @ inv(T1_w2c)``).
        """
        data_name = self.data_names[pair_idx]
        scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
        scene_name = f"scene{scene_name:04d}_{scene_sub_name:02d}"
        scene_dir = osp.join(self.scene_root, scene_name)

        # Both frames come from the same scan, so they share one intrinsic file.
        K1 = K2 = self.read_scannet_intrinsic(
            osp.join(scene_dir, "intrinsic", "intrinsic_color.txt")
        )

        # Poses are world-to-camera; compose them into camera-1 -> camera-2.
        T1 = self.read_scannet_pose(osp.join(scene_dir, "pose", f"{stem_name_1}.txt"))
        T2 = self.read_scannet_pose(osp.join(scene_dir, "pose", f"{stem_name_2}.txt"))
        T_1to2 = torch.tensor(T2 @ np.linalg.inv(T1), dtype=torch.float)

        im_A_ref = os.path.join(scene_dir, "color", f"{stem_name_1}.jpg")
        im_B_ref = os.path.join(scene_dir, "color", f"{stem_name_2}.jpg")
        depth_A_ref = os.path.join(scene_dir, "depth", f"{stem_name_1}.png")
        depth_B_ref = os.path.join(scene_dir, "depth", f"{stem_name_2}.png")

        im_A = self.load_im(im_A_ref)
        im_B = self.load_im(im_B_ref)
        depth_A = self.load_depth(depth_A_ref)
        depth_B = self.load_depth(depth_B_ref)

        # Intrinsics must be rescaled using the original image size, i.e.
        # before the images themselves are resized.
        K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
        K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)

        im_A, im_B = self.im_transform_ops((im_A, im_B))
        # The depth maps get two leading singleton dims for the transform ops.
        depth_A, depth_B = self.depth_transform_ops(
            (depth_A[None, None], depth_B[None, None])
        )

        # Short-circuit keeps the RNG untouched when the augmentation is off.
        if self.use_horizontal_flip_aug and np.random.rand() > 0.5:
            im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(
                im_A, im_B, depth_A, depth_B, K1, K2
            )

        data_dict = {
            "im_A": im_A,
            "im_B": im_B,
            "im_A_depth": depth_A[0, 0],
            "im_B_depth": depth_B[0, 0],
            "K1": K1,
            "K2": K2,
            "T_1to2": T_1to2,
        }
        return data_dict
|
|
|
|
|
class ScanNetBuilder:
    """Builds one ``ScanNetScene`` per index file found under the data root.

    Attributes:
        data_root: dataset root directory.
        scene_info_root: ``<data_root>/scannet_indices``.
        all_scenes: entries of ``scene_info_root`` as returned by
            ``os.listdir`` (order is whatever the filesystem yields).
    """

    def __init__(self, data_root="data/scannet") -> None:
        """List the index files available under ``<data_root>/scannet_indices``."""
        self.data_root = data_root
        self.scene_info_root = os.path.join(data_root, "scannet_indices")
        self.all_scenes = os.listdir(self.scene_info_root)

    def build_scenes(self, split="train", min_overlap=0.0, **kwargs):
        """Create a ``ScanNetScene`` for every index file in ``all_scenes``.

        Args:
            split: currently unused; every index file is loaded regardless.
            min_overlap: forwarded to ``ScanNetScene``.
            **kwargs: forwarded to ``ScanNetScene``.

        Returns:
            list of ``ScanNetScene`` objects, one per index file.
        """
        scenes = []
        # The progress bar is disabled whenever roma.RANK > 0.
        for scene_name in tqdm(self.all_scenes, disable=roma.RANK > 0):
            # NOTE(security): allow_pickle=True can execute arbitrary code
            # while loading; only use index files from a trusted source.
            scene_info = np.load(
                os.path.join(self.scene_info_root, scene_name), allow_pickle=True
            )
            try:
                scenes.append(
                    ScanNetScene(
                        self.data_root, scene_info, min_overlap=min_overlap, **kwargs
                    )
                )
            finally:
                # An .npz archive keeps its file handle open until closed.
                # ScanNetScene indexes (copies) the arrays it keeps, so the
                # handle can be released right away. Plain arrays have no
                # close() and are skipped.
                close = getattr(scene_info, "close", None)
                if close is not None:
                    close()
        return scenes

    def weight_scenes(self, concat_dataset, alpha=0.5):
        """Return per-sample weights that down-weight large scenes.

        Every sample of a dataset of length ``n`` gets weight
        ``1 / n**alpha``; the per-dataset weights are concatenated in the
        order of ``concat_dataset.datasets``.

        Args:
            concat_dataset: object exposing a ``datasets`` sequence.
            alpha: exponent applied to each dataset length.

        Returns:
            1-D float tensor with one weight per sample; empty when there
            are no datasets.
        """
        ns = [len(d) for d in concat_dataset.datasets]
        if not ns:
            # torch.cat rejects an empty list; return an empty weight vector.
            return torch.ones(0)
        ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
        return ws
|
|