# File size: 7,089 Bytes (f53b39e)
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Dataloader for preprocessed Co3d_v2
# dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International
# See datasets_preprocess/preprocess_co3d.py
# --------------------------------------------------------
import os.path as osp
import json
import itertools
from collections import deque
import cv2
import numpy as np
from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
from dust3r.utils.image import imread_cv2
class Co3d(BaseStereoViewDataset):
    """Dataset yielding pairs of views from preprocessed Co3d_v2 sequences.

    A sample index is decomposed into a scene (an ``(object, instance)`` key
    read from ``selected_seqs_{split}.json``) and a pair of positions inside
    that scene's list of frame numbers. Every view carries an RGB image, a
    depth map, a camera pose and camera intrinsics.

    Frames whose depth map has no strictly positive value after cropping are
    remembered in ``self.invalidate`` (per scene and per resolution) and are
    replaced by a neighbouring frame on subsequent accesses.
    """

    def __init__(self, mask_bg=True, *args, ROOT, **kwargs):
        """Load the scene index and precompute the candidate index pairs.

        Args:
            mask_bg: ``True`` to always zero the depth outside the object
                mask, ``False`` to never do it, ``'rand'`` to decide per
                sample using the sample's random generator.
            *args: Forwarded to ``BaseStereoViewDataset.__init__``.
            ROOT: Root directory of the preprocessed dataset; it must contain
                ``selected_seqs_{split}.json``.
            **kwargs: Forwarded to ``BaseStereoViewDataset.__init__``.
        """
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        assert mask_bg in (True, False, 'rand')
        self.mask_bg = mask_bg
        self.dataset_label = 'Co3d_v2'

        # load all scenes
        # NOTE(review): self.split is presumably set by the base-class
        # constructor called above -- confirm.
        with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f:
            self.scenes = json.load(f)
        # drop objects without any instance, then flatten the two-level
        # mapping into {(object, instance): frame numbers}
        self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0}
        self.scenes = {(k, k2): v2 for k, v in self.scenes.items()
                       for k2, v2 in v.items()}
        self.scene_list = list(self.scenes.keys())

        # Candidate pairs (i, j), i < j, among positions 0..99 whose gap is a
        # multiple of 5 and at most 30 positions.
        # The original note states that a scene has 100 images spanning 360
        # degrees; under that assumption the gaps correspond to roughly
        # 18..108 degrees -- TODO confirm against the preprocessing script.
        self.combinations = [(i, j)
                             for i, j in itertools.combinations(range(100), 2)
                             if 0 < abs(i - j) <= 30 and abs(i - j) % 5 == 0]

        # per scene: {resolution: [is_invalid flag for each pool position]}
        self.invalidate = {scene: {} for scene in self.scene_list}

    def __len__(self):
        """Return the number of (scene, index pair) samples."""
        return len(self.scene_list) * len(self.combinations)

    # The frame number is formatted with 'd' rather than the locale-aware 'n'
    # so that file names do not depend on the process locale.
    def _get_metadatapath(self, obj, instance, view_idx):
        """Return the path of the ``.npz`` camera metadata of a frame."""
        return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06d}.npz')

    def _get_impath(self, obj, instance, view_idx):
        """Return the path of the RGB image of a frame."""
        return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06d}.jpg')

    def _get_depthpath(self, obj, instance, view_idx):
        """Return the path of the depth map of a frame."""
        return osp.join(self.ROOT, obj, instance, 'depths', f'frame{view_idx:06d}.jpg.geometric.png')

    def _get_maskpath(self, obj, instance, view_idx):
        """Return the path of the object mask of a frame."""
        return osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06d}.png')

    def _read_depthmap(self, depthpath, input_metadata):
        """Read a depth image and rescale it with the frame's maximum depth.

        The stored values are divided by 65535 (presumably a 16-bit PNG --
        TODO confirm) and multiplied by ``input_metadata['maximum_depth']``,
        with NaN/inf in that scalar sanitised by ``np.nan_to_num``.
        """
        depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
        depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(input_metadata['maximum_depth'])
        return depthmap

    def _get_views(self, idx, resolution, rng):
        """Build the two views of sample ``idx``.

        Args:
            idx: Sample index in ``range(len(self))``.
            resolution: Target resolution, forwarded to
                ``_crop_resize_if_necessary``; also used as a dictionary key,
                so it must be hashable.
            rng: Random generator exposing ``integers`` and ``choice``
                (presumably a ``numpy.random.Generator`` -- confirm with the
                base class).

        Returns:
            A list of two dicts with keys ``img``, ``depthmap``,
            ``camera_pose``, ``camera_intrinsics``, ``dataset``, ``label``
            and ``instance``.

        Raises:
            RuntimeError: If every image of the scene has been flagged as
                invalid for this resolution, so no view can be produced.
        """
        # choose a scene
        obj, instance = self.scene_list[idx // len(self.combinations)]
        image_pool = self.scenes[obj, instance]
        im1_idx, im2_idx = self.combinations[idx % len(self.combinations)]

        last = len(image_pool) - 1

        # lazily create the invalid-image flags for this resolution
        scene_flags = self.invalidate[obj, instance]
        if resolution not in scene_flags:
            scene_flags[resolution] = [False for _ in range(len(image_pool))]
        invalid = scene_flags[resolution]

        # decide now if we mask the bg (the rng is only consulted for 'rand')
        if self.mask_bg == 'rand':
            mask_bg = bool(rng.choice(2))
        else:
            mask_bg = bool(self.mask_bg)

        views = []
        # add a bit of randomness: jitter each position, then clamp it to the
        # pool. The list is reversed because the deque is consumed from the
        # right, so the view derived from im1_idx comes first.
        imgs_idxs = deque(max(0, min(im_idx + rng.integers(-4, 5), last))
                          for im_idx in [im2_idx, im1_idx])
        while len(imgs_idxs) > 0:  # some images (few) have zero depth
            im_idx = imgs_idxs.pop()

            if invalid[im_idx]:
                # search for a valid image, walking the pool circularly in a
                # random direction
                random_direction = 2 * rng.choice(2) - 1
                for offset in range(1, len(image_pool)):
                    tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool)
                    if not invalid[tentative_im_idx]:
                        im_idx = tentative_im_idx
                        break
                else:
                    # Every position is flagged: retrying would loop forever.
                    raise RuntimeError(
                        f'no image with valid depth left in scene {obj}/{instance} '
                        f'at resolution {resolution}')

            view_idx = image_pool[im_idx]

            impath = self._get_impath(obj, instance, view_idx)
            depthpath = self._get_depthpath(obj, instance, view_idx)

            # load camera params; the context manager closes the .npz archive
            # as soon as everything needed has been read from it
            metadata_path = self._get_metadatapath(obj, instance, view_idx)
            with np.load(metadata_path) as input_metadata:
                camera_pose = input_metadata['camera_pose'].astype(np.float32)
                intrinsics = input_metadata['camera_intrinsics'].astype(np.float32)

                # load image and depth
                rgb_image = imread_cv2(impath)
                depthmap = self._read_depthmap(depthpath, input_metadata)

            if mask_bg:
                # load object mask
                maskpath = self._get_maskpath(obj, instance, view_idx)
                maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32)
                maskmap = (maskmap / 255.0) > 0.1

                # update the depthmap with mask
                depthmap *= maskmap

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)

            num_valid = (depthmap > 0.0).sum()
            if num_valid == 0:
                # problem, invalidate image and retry
                invalid[im_idx] = True
                imgs_idxs.append(im_idx)
                continue

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset=self.dataset_label,
                label=osp.join(obj, instance),
                instance=osp.split(impath)[1],
            ))
        return views
if __name__ == "__main__":
    # Visual sanity check: iterate over the dataset in random order and show
    # the point clouds and cameras of each pair of views.
    from dust3r.datasets.base.base_stereo_view_dataset import view_name
    from dust3r.viz import SceneViz, auto_cam_size
    from dust3r.utils.image import rgb

    dataset = Co3d(split='train', ROOT="data/co3d_subset_processed", resolution=224, aug_crop=16)

    for idx in np.random.permutation(len(dataset)):
        views = dataset[idx]
        assert len(views) == 2
        print(view_name(views[0]), view_name(views[1]))
        viz = SceneViz()
        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]
        cam_size = max(auto_cam_size(poses), 0.001)
        for view_idx in [0, 1]:
            pts3d = views[view_idx]['pts3d']
            valid_mask = views[view_idx]['valid_mask']
            colors = rgb(views[view_idx]['img'])
            viz.add_pointcloud(pts3d, colors, valid_mask)
            # Colour the camera by its position in the pair (view 0 ->
            # (0, 255, 0), view 1 -> (255, 0, 0)); using the dataset index
            # here would produce out-of-range colour values.
            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],
                           focal=views[view_idx]['camera_intrinsics'][0, 0],
                           color=(view_idx * 255, (1 - view_idx) * 255, 0),
                           image=colors,
                           cam_size=cam_size)
        viz.show()
|