Spaces:
Running
Running
File size: 7,220 Bytes
fb98d2a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# InLoc dataloader
# --------------------------------------------------------
import os
import numpy as np
import torch
import PIL.Image
import scipy.io
import kapture
from kapture.io.csv import kapture_from_dir
from kapture_localization.utils.pairsfile import get_ordered_pairs_from_file
from dust3r_visloc.datasets.utils import cam_to_world_from_kapture, get_resize_function, rescale_points3d
from dust3r_visloc.datasets.base_dataset import BaseVislocDataset
from dust3r.datasets.utils.transforms import ImgNorm
from dust3r.utils.geometry import xy_grid, geotrf
def read_alignments(path_to_alignment):
    """Parse an InLoc ``all_transformations.txt`` alignment file.

    The file is scanned line by line for header lines made of exactly three
    characters (the scan id, e.g. ``024``). After each header, lines are
    skipped up to the ``After general icp:`` marker, and the next four lines
    are read as the rows of a 4x4 transformation matrix.

    Args:
        path_to_alignment: path to the alignment text file.

    Returns:
        dict mapping each scan id (str) to a (4, 4) float64 numpy array.

    Raises:
        ValueError: if the file ends before the ``After general icp:`` marker
            that should follow a header, or if a matrix does not hold exactly
            16 numbers.
    """
    marker = 'After general icp:'
    aligns = {}
    with open(path_to_alignment, "r") as fid:
        # A single shared iterator: inner loops consume lines and the outer
        # loop resumes right after them.
        lines = iter(fid)
        for line in lines:
            trans_nr = line.rstrip('\n')
            if len(trans_nr) != 3:
                continue
            # Skip ahead to the ICP-refined transformation. The for/else
            # detects EOF instead of spinning forever on a truncated file.
            for line in lines:
                if line.rstrip() == marker:
                    break
            else:
                raise ValueError(
                    f"missing {marker!r} after header {trans_nr!r} in {path_to_alignment}")
            # Read exactly four rows of whitespace-separated numbers; split()
            # tolerates repeated/trailing spaces, and not reading a fifth line
            # keeps a header that directly follows the matrix from being lost.
            values = []
            for _ in range(4):
                values.extend(float(tok) for tok in next(lines, '').split())
            if len(values) != 16:
                raise ValueError(
                    f"expected 16 values for scan {trans_nr!r} in {path_to_alignment}, "
                    f"got {len(values)}")
            aligns[trans_nr] = np.array(values).reshape(4, 4)
    return aligns
class VislocInLoc(BaseVislocDataset):
    """InLoc visual-localization dataset read from kapture-formatted data.

    Each item is a list of views: the query image first, followed by up to
    ``topk`` mapping images paired with it in ``pairsfile``. Mapping views
    additionally carry 3D points loaded from the cutouts' ``.mat`` files and
    transformed with the DUC1/DUC2 scan alignments.
    """

    def __init__(self, root, pairsfile, topk=1):
        """Load the query/mapping kapture data, image pairs and scan alignments.

        Args:
            root: dataset root holding ``query/``, ``mapping/`` and
                ``pairfiles/query/``.
            pairsfile: name (without ``.txt``) of the pairs file in
                ``pairfiles/query/``; kapture format, or hloc format as fallback.
            topk: number of paired mapping images returned per query.
        """
        super().__init__()
        self.root = root
        self.topk = topk
        self.num_views = self.topk + 1
        # Both must be set before indexing (checked in __getitem__).
        self.maxdim = None
        self.patch_size = None

        query_path = os.path.join(self.root, 'query')
        kdata_query = kapture_from_dir(query_path)
        assert kdata_query.records_camera is not None
        self.query_data = {'path': query_path, 'kdata': kdata_query,
                           'searchindex': self._build_searchindex(kdata_query)}

        map_path = os.path.join(self.root, 'mapping')
        kdata_map = kapture_from_dir(map_path)
        assert kdata_map.records_camera is not None and kdata_map.trajectories is not None
        self.map_data = {'path': map_path, 'kdata': kdata_map,
                         'searchindex': self._build_searchindex(kdata_map)}

        pairs_path = os.path.join(self.root, 'pairfiles/query', pairsfile + '.txt')
        try:
            self.pairs = get_ordered_pairs_from_file(pairs_path)
        except Exception:
            # Not readable as a kapture pairs file: assume pairs produced by hloc.
            self.pairs = self._read_hloc_pairs(pairs_path)

        self.scenes = kdata_query.records_camera.data_list()
        self.aligns_DUC1 = read_alignments(os.path.join(self.root, 'mapping/DUC1_alignment/all_transformations.txt'))
        self.aligns_DUC2 = read_alignments(os.path.join(self.root, 'mapping/DUC2_alignment/all_transformations.txt'))

    @staticmethod
    def _build_searchindex(kdata):
        """Map each image name in ``kdata.records_camera`` to its (timestamp, sensor_id) key."""
        records = kdata.records_camera
        return {records[(timestamp, sensor_id)]: (timestamp, sensor_id)
                for timestamp, sensor_id in records.key_pairs()}

    @staticmethod
    def _read_hloc_pairs(pairs_path):
        """Parse an hloc pairs file into ``{query_name: [(map_name, 1.0), ...]}``.

        Each non-blank line holds two space-separated names; the ``query/`` and
        ``database/cutouts/`` substrings are removed so the names match the
        kapture records. Every pair gets a constant score of 1.0.
        """
        pairs = {}
        with open(pairs_path, 'r') as fid:
            for line in fid:
                # Blank lines (e.g. a trailing empty line) carry no pair and
                # would otherwise raise IndexError on splits[1].
                if not line.strip():
                    continue
                splits = line.rstrip("\n\r").split(" ")
                pairs.setdefault(splits[0].replace('query/', ''), []).append(
                    (splits[1].replace('database/cutouts/', ''), 1.0)
                )
        return pairs

    def __len__(self):
        """Return the number of query images."""
        return len(self.scenes)

    def __getitem__(self, idx):
        """Return the list of views for query number ``idx``.

        The first view is the query image, the following ones are its top-k
        paired mapping images. Every view holds 'intrinsics', 'distortion',
        'cam_to_world', the PIL image ('rgb'), the normalized and resized
        tensor ('rgb_rescaled'), 'to_orig', its position in the list ('idx')
        and 'image_name'. Mapping views also hold 'pts3d', 'valid',
        'pts3d_rescaled' and 'valid_rescaled'.
        """
        assert self.maxdim is not None and self.patch_size is not None
        query_image = self.scenes[idx]
        map_images = [p[0] for p in self.pairs[query_image][:self.topk]]
        dataarray = [(query_image, self.query_data, False)]
        dataarray += [(map_image, self.map_data, True) for map_image in map_images]

        views = []
        # view_idx is the view's position in the returned list (0 = query),
        # kept distinct from the dataset index ``idx``.
        for view_idx, (imgname, data, should_load_depth) in enumerate(dataarray):
            imgpath, kdata, searchindex = data['path'], data['kdata'], data['searchindex']
            timestamp, camera_id = searchindex[imgname]

            # For InLoc, SIMPLE_PINHOLE: (width, height, f, cx, cy). The stored
            # width/height are ignored; the loaded image's size is used instead.
            _, _, f, cx, cy = kdata.sensors[camera_id].camera_params
            distortion = [0, 0, 0, 0]
            intrinsics = np.float32([(f, 0, cx),
                                     (0, f, cy),
                                     (0, 0, 1)])
            if kdata.trajectories is not None and (timestamp, camera_id) in kdata.trajectories:
                cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id)
            else:
                # No pose recorded for this image: identity placeholder.
                cam_to_world = np.eye(4, dtype=np.float32)

            # Load RGB image; the context manager closes the source file, and
            # convert() returns an independent, fully loaded image.
            with PIL.Image.open(os.path.join(imgpath, 'sensors/records_data', imgname)) as raw_image:
                rgb_image = raw_image.convert('RGB')
            W, H = rgb_image.size
            resize_func, to_resize, to_orig = get_resize_function(self.maxdim, self.patch_size, H, W)
            rgb_tensor = resize_func(ImgNorm(rgb_image))

            view = {
                'intrinsics': intrinsics,
                'distortion': distortion,
                'cam_to_world': cam_to_world,
                'rgb': rgb_image,
                'rgb_rescaled': rgb_tensor,
                'to_orig': to_orig,
                'idx': view_idx,
                'image_name': imgname
            }

            # Load depthmap (mapping images only)
            if should_load_depth:
                depthmap_filename = os.path.join(imgpath, 'sensors/records_data', imgname + '.mat')
                depthmap = scipy.io.loadmat(depthmap_filename)
                # NOTE(review): assumes XYZcut is (H, W, 3), pixel-aligned with the image.
                pt3d_cut = depthmap['XYZcut']
                # Second path component selects the scan alignment
                # (presumably e.g. 'DUC1/024/...' -> '024').
                scene_id = imgname.replace('\\', '/').split('/')[1]
                aligns = self.aligns_DUC1 if imgname.startswith('DUC1') else self.aligns_DUC2
                pts3d_full = geotrf(aligns[scene_id], pt3d_cut)

                # A pixel is valid when the sum of its coordinates is finite
                # (no NaN/inf coordinate).
                pts3d_valid = np.isfinite(pts3d_full.sum(axis=-1))
                pts3d = pts3d_full[pts3d_valid]
                pts2d = xy_grid(W, H)[pts3d_valid].astype(np.float64)
                # nan => invalid
                pts3d_full[~pts3d_valid] = np.nan
                pts3d_full = torch.from_numpy(pts3d_full)
                view['pts3d'] = pts3d_full
                view["valid"] = pts3d_full.sum(dim=-1).isfinite()

                HR, WR = rgb_tensor.shape[1:]
                _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(pts2d, pts3d, to_resize, HR, WR)
                view['pts3d_rescaled'] = torch.from_numpy(pts3d_rescaled)
                view["valid_rescaled"] = torch.from_numpy(valid_rescaled)
            views.append(view)
        return views
|