File size: 5,257 Bytes
9223079 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
from pathlib import Path
import numpy as np
import torch
import PIL.Image
from tqdm import tqdm
import pycolmap
from ...utils.read_write_model import write_model, read_model
def scene_coordinates(p2D, R_w2c, t_w2c, depth, camera):
    """Lift image points with a known depth to 3D points in the world frame.

    Args:
        p2D: image points, one row per point, in the format accepted by
            ``pycolmap.image_to_world``.
        R_w2c: rotation matrix of the pose (world-to-camera, going by the
            name).
        t_w2c: translation vector of the same pose.
        depth: one depth value per point; it scales the normalized ray whose
            last coordinate is 1, i.e. it is the z-coordinate in the camera
            frame.
        camera: camera object exposing ``_asdict()``, forwarded to pycolmap.

    Returns:
        Array with one 3D point per input row.
    """
    assert len(depth) == len(p2D)
    # Let pycolmap undo the camera model: pixels -> normalized image plane.
    unprojected = pycolmap.image_to_world(p2D, camera._asdict())
    normalized = np.asarray(unprojected["world_points"])
    # Homogeneous rays with unit last coordinate, then scaled by the depth.
    unit_column = np.ones_like(normalized[:, :1])
    rays = np.concatenate([normalized, unit_column], 1)
    points_cam = rays * depth[:, None]
    # Row-vector form of R^T (p_c - t): inverts p_c = R p_w + t.
    return (points_cam - t_w2c) @ R_w2c
def interpolate_depth(depth, kp):
    """Sample a depth map at sub-pixel keypoint locations.

    Bilinear interpolation is used wherever it yields a finite value; where it
    yields NaN, the nearest-neighbour sample is used instead so that as many
    keypoints as possible receive a depth.

    Args:
        depth: 2D array of shape (H, W). NaN entries propagate through the
            interpolation and mark missing depth. Any real dtype is accepted;
            it is converted to float64.
        kp: array of shape (N, 2) with (x, y) pixel coordinates. Every
            coordinate must lie strictly inside the image, i.e.
            0 < x < W - 1 and 0 < y < H - 1.

    Returns:
        A tuple ``(interp_depth, valid)`` where ``interp_depth`` is a float64
        array of shape (N,) (NaN where no depth could be sampled) and
        ``valid`` is a boolean array of shape (N,) that is True where
        ``interp_depth`` is not NaN.

    Raises:
        AssertionError: if a keypoint lies on or outside the image border.
    """
    h, w = depth.shape
    # Map pixel coordinates to the [-1, 1] range expected by grid_sample;
    # with align_corners=True, -1 and 1 are the centres of the border pixels.
    kp = np.asarray(kp, dtype=np.float64) / np.array([[w - 1, h - 1]]) * 2 - 1
    assert np.all(kp > -1) and np.all(kp < 1)

    # grid_sample requires the input and the grid to share a dtype. The grid
    # is float64 after the normalization above, so the depth map is brought
    # to float64 as well (a float32 depth map would otherwise raise).
    depth = torch.from_numpy(np.asarray(depth, dtype=np.float64))[None, None]
    kp = torch.from_numpy(kp)[None, None]
    grid_sample = torch.nn.functional.grid_sample

    # To maximize the number of points that have depth:
    # do bilinear interpolation first and then nearest for the remaining points
    # Input is (1, 1, H, W) and grid is (1, 1, N, 2), so the output is
    # (1, 1, 1, N); indexing [0, 0, 0] leaves the N samples.
    interp_lin = grid_sample(depth, kp, align_corners=True, mode="bilinear")
    interp_nn = grid_sample(depth, kp, align_corners=True, mode="nearest")
    interp_lin, interp_nn = interp_lin[0, 0, 0], interp_nn[0, 0, 0]
    interp = torch.where(torch.isnan(interp_lin), interp_nn, interp_lin)

    valid = ~torch.isnan(interp)
    return interp.numpy(), valid.numpy()
def image_path_to_rendered_depth_path(image_name):
    """Derive the rendered-depth file name from a ``<sequence>/<frame>`` image name.

    The first two ``/``-separated components are joined with ``_`` after the
    dashes are stripped from the first one; then every ``color`` becomes
    ``pose`` and every ``png`` becomes ``depth.tiff``.

    Example: ``seq-01/frame-000000.color.png`` ->
    ``seq01_frame-000000.pose.depth.tiff``.
    """
    components = image_name.split("/")
    sequence = components[0].replace("-", "")
    frame = components[1]
    renamed = f"{sequence}_{frame}"
    # The substitutions run over the whole joined name, not only the suffix.
    return renamed.replace("color", "pose").replace("png", "depth.tiff")
def project_to_image(p3D, R, t, camera, eps: float = 1e-4, pad: int = 1):
    """Project 3D points into an image and keep those that land inside it.

    Args:
        p3D: array with one 3D point per row.
        R: rotation matrix applied to the points (as ``p3D @ R.T``).
        t: translation added after the rotation.
        camera: camera object exposing ``width``, ``height`` and
            ``_asdict()``; the dict is forwarded to pycolmap.
        eps: minimum z-coordinate (after the transform) for a point to count
            as in front of the camera; also the floor used for the
            perspective division.
        pad: margin in pixels that a projection must keep from the image
            border to be accepted.

    Returns:
        A tuple ``(p2D, valid)``: the projections of the accepted points only,
        and a boolean mask over all input points telling which were accepted.
    """
    p3D_cam = p3D @ R.T + t
    z = p3D_cam[:, -1]
    in_front = z >= eps  # keep points in front of the camera
    # Clamp the divisor so points at or behind the camera do not divide by
    # zero; they are discarded through `in_front` anyway.
    normalized = p3D_cam[:, :-1] / z[:, None].clip(min=eps)
    projected = pycolmap.world_to_image(normalized, camera._asdict())
    p2D = np.asarray(projected["image_points"])
    upper = np.array([camera.width - pad - 1, camera.height - pad - 1])
    inside = np.all((p2D >= pad) & (p2D <= upper), -1)
    valid = inside & in_front
    return p2D[valid], valid
def correct_sfm_with_gt_depth(sfm_path, depth_folder_path, output_path):
    """Replace the 3D points of an SfM model with depth-map back-projections.

    For every image of the model read from ``sfm_path``, the 3D points it
    observes are projected into the image, the depth map is sampled at the
    projections, and the points are moved to the resulting scene coordinates.
    Observations whose point does not project inside the image, or for which
    no depth is available, are removed from both the image and the point
    track.

    Args:
        sfm_path: location of the input model, passed to ``read_model``.
        depth_folder_path: folder holding one depth map per image, named as
            returned by ``image_path_to_rendered_depth_path``. Depth values
            are divided by 1000 (millimetres to metres).
        output_path: folder the corrected model is written to (str or Path);
            created if missing.
    """
    cameras, images, points3D = read_model(sfm_path)
    depth_folder_path = Path(depth_folder_path)
    for imgid, img in tqdm(images.items()):
        p3D_ids = img.point3D_ids
        observed = p3D_ids != -1
        if not observed.any():
            # No 3D point is observed in this image: there is nothing to
            # correct, and np.stack below would fail on an empty list.
            continue
        observed_ids = p3D_ids[observed]

        depth_name = image_path_to_rendered_depth_path(img.name)
        # Close the file deterministically; Pillow keeps the handle of a TIFF
        # file open after the pixel data has been loaded.
        with PIL.Image.open(depth_folder_path / depth_name) as depth_image:
            depth = np.array(depth_image).astype("float64")
        depth = depth / 1000.0  # mm to meter
        # Zero and very large values are treated as missing depth.
        depth[(depth == 0.0) | (depth > 1000.0)] = np.nan

        R_w2c, t_w2c = img.qvec2rotmat(), img.tvec
        camera = cameras[img.camera_id]
        p3Ds = np.stack([points3D[i].xyz for i in observed_ids], 0)
        p2Ds, valids_projected = project_to_image(p3Ds, R_w2c, t_w2c, camera)
        interp_depth, valids_backprojected = interpolate_depth(depth, p2Ds)
        scs = scene_coordinates(
            p2Ds[valids_backprojected],
            R_w2c,
            t_w2c,
            interp_depth[valids_backprojected],
            camera,
        )

        # Mask over the observed points: kept only if the point projects
        # into the image AND a depth value could be sampled there.
        keep = valids_projected.copy()
        keep[valids_projected] = valids_backprojected

        # Drop this image from the track of every rejected point.
        for p3did in observed_ids[~keep]:
            point = points3D[p3did]
            stale = np.where(point.image_ids == img.id)[0]
            points3D[p3did] = point._replace(
                image_ids=np.delete(point.image_ids, stale),
                point2D_idxs=np.delete(point.point2D_idxs, stale),
            )

        # Mark the rejected observations as "no 3D point" in the image.
        new_p3D_ids = p3D_ids.copy()
        new_p3D_ids[observed] = np.where(keep, observed_ids, -1)
        img = img._replace(point3D_ids=new_p3D_ids)

        kept_ids = observed_ids[keep]
        assert len(kept_ids) == len(scs), f"{len(scs)}, {len(kept_ids)}"
        # A point seen in several images ends up with the coordinates
        # computed from the last image processed.
        for p3did, xyz in zip(kept_ids, scs):
            points3D[p3did] = points3D[p3did]._replace(xyz=xyz)
        images[imgid] = img

    output_path = Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)
    write_model(cameras, images, points3D, output_path)
if __name__ == "__main__":
    # Correct the SuperPoint+SuperGlue SfM model of every 7Scenes scene with
    # the depth maps found in that scene's training depth folder.
    dataset = Path("datasets/7scenes")
    outputs = Path("outputs/7Scenes")
    SCENES = ["chess", "fire", "heads", "office", "pumpkin", "redkitchen", "stairs"]
    for scene in SCENES:
        scene_outputs = outputs / scene
        correct_sfm_with_gt_depth(
            scene_outputs / "sfm_superpoint+superglue",
            dataset / f"depth/7scenes_{scene}/train/depth",
            scene_outputs / "sfm_superpoint+superglue+depth",
        )
|