|
from pathlib import Path |
|
import numpy as np |
|
import torch |
|
import PIL.Image |
|
from tqdm import tqdm |
|
import pycolmap |
|
|
|
from ...utils.read_write_model import write_model, read_model |
|
|
|
|
|
def scene_coordinates(p2D, R_w2c, t_w2c, depth, camera):
    """Lift 2D image points with a known depth to 3D points.

    Args:
        p2D: 2D image points, one row per point; forwarded to
            ``pycolmap.image_to_world``.
        R_w2c: rotation matrix (presumably world-to-camera, per its name).
        t_w2c: translation vector paired with ``R_w2c``.
        depth: 1D array with one depth value per row of ``p2D``.
        camera: camera record exposing ``_asdict()``.

    Returns:
        Array with one 3D point per input row, obtained by inverting
        ``x_cam = R_w2c @ x + t_w2c``.
    """
    assert len(depth) == len(p2D)

    # Let pycolmap turn the image points into normalized coordinates.
    unprojected = pycolmap.image_to_world(p2D, camera._asdict())
    normalized = np.asarray(unprojected["world_points"])

    # Append a unit third coordinate so that scaling by the depth yields a
    # point whose last component equals that depth.
    unit_column = np.ones_like(normalized[:, :1])
    rays = np.concatenate([normalized, unit_column], 1)
    points_cam = rays * depth[:, None]

    # Row-vector form of R^T (x_cam - t).
    return (points_cam - t_w2c) @ R_w2c
|
|
|
|
|
def interpolate_depth(depth, kp):
    """Sample a depth map at sub-pixel keypoint locations.

    Bilinear interpolation is tried first; wherever it yields NaN, the
    nearest-neighbour sample is used instead.

    Args:
        depth: 2D numpy array (height x width); NaN marks missing depth.
        kp: numpy array with one ``(x, y)`` pixel coordinate per row. Every
            coordinate must lie strictly inside the image borders.

    Returns:
        Tuple ``(interp_depth, valid)``: a 1D array with one sampled depth
        per keypoint, and a boolean array that is False where the sampled
        depth is still NaN.

    Raises:
        AssertionError: if a keypoint lies on or outside the image border.
    """
    height, width = depth.shape

    # Rescale pixel coordinates to the [-1, 1] range grid_sample expects
    # (0 maps to -1 and size - 1 maps to +1, matching align_corners=True).
    grid = kp / np.array([[width - 1, height - 1]]) * 2 - 1
    assert np.all(grid > -1) and np.all(grid < 1)

    # Add the leading batch/channel dims (depth) and batch/row dims (grid).
    depth_t = torch.from_numpy(depth)[None, None]
    grid_t = torch.from_numpy(grid)[None, None]

    def sample(mode):
        # Drop the batch and row dims, keeping (channel, keypoint).
        sampled = torch.nn.functional.grid_sample(
            depth_t, grid_t, align_corners=True, mode=mode
        )
        return sampled[0, :, 0]

    bilinear = sample("bilinear")
    nearest = sample("nearest")

    # Prefer the bilinear value; fall back to nearest where it is NaN.
    filled = torch.where(torch.isnan(bilinear), nearest, bilinear)
    still_missing = torch.any(torch.isnan(filled), 0)

    return filled.T.numpy().flatten(), (~still_missing).numpy()
|
|
|
|
|
def image_path_to_rendered_depth_path(image_name):
    """Derive the depth-map file name that corresponds to an image name.

    The first two ``/``-separated components of ``image_name`` are used: the
    dashes are removed from the first one, the two are joined with ``_``, and
    then every ``color`` becomes ``pose`` and every ``png`` becomes
    ``depth.tiff``. For example ``seq-01/frame-000000.color.png`` maps to
    ``seq01_frame-000000.pose.depth.tiff``.

    Args:
        image_name: image name containing at least one ``/``.

    Returns:
        The derived file name as a string.

    Raises:
        IndexError: if ``image_name`` contains no ``/``.
    """
    components = image_name.split("/")
    sequence = components[0].replace("-", "")
    frame = components[1]
    renamed = f"{sequence}_{frame}".replace("color", "pose")
    return renamed.replace("png", "depth.tiff")
|
|
|
|
|
def project_to_image(p3D, R, t, camera, eps: float = 1e-4, pad: int = 1):
    """Project 3D points into an image and keep those that land inside it.

    Args:
        p3D: array with one 3D point per row.
        R: rotation matrix applied to the points (``p3D @ R.T``).
        t: translation added after the rotation.
        camera: camera record exposing ``_asdict()``, ``width`` and ``height``.
        eps: smallest last-coordinate value (after the transform) for a point
            to count as visible; also the lower clamp of the divisor.
        pad: margin kept free along each image border.

    Returns:
        Tuple ``(p2D, valid)``: the projections of the accepted points only,
        and a boolean mask over all input points marking which were accepted.
    """
    points_cam = (p3D @ R.T) + t

    # Points at or behind the camera plane are rejected; the divisor is
    # clamped so they still go through the division without blowing up.
    in_front = points_cam[:, -1] >= eps
    normalized = points_cam[:, :-1] / points_cam[:, -1:].clip(min=eps)

    projected = pycolmap.world_to_image(normalized, camera._asdict())
    p2D = np.asarray(projected["image_points"])

    # Accept only projections at least `pad` away from every border.
    upper = np.array([camera.width - pad - 1, camera.height - pad - 1])
    in_bounds = np.all((p2D >= pad) & (p2D <= upper), -1)

    valid = in_bounds & in_front
    return p2D[valid], valid
|
|
|
|
|
def correct_sfm_with_gt_depth(sfm_path, depth_folder_path, output_path):
    """Re-create the 3D points of an SfM model from per-image depth maps.

    For every image of the model read from ``sfm_path``, the 3D points it
    observes are projected into the image, the image's depth map is sampled
    at the projections, and the points are replaced by the 3D positions
    lifted from those depths. Observations that project outside the image or
    have no depth are removed from both the image and the point tracks.
    A point observed by several images ends up with the position computed
    from the last image processed. The updated model is written to
    ``output_path``.

    Args:
        sfm_path: location of the input model, passed to ``read_model``.
        depth_folder_path: folder holding the depth maps, named as returned
            by ``image_path_to_rendered_depth_path``.
        output_path: folder the corrected model is written to (``str`` or
            ``Path``); created if missing.
    """
    depth_folder_path = Path(depth_folder_path)
    output_path = Path(output_path)

    cameras, images, points3D = read_model(sfm_path)
    for imgid, img in tqdm(images.items()):
        p3D_ids = img.point3D_ids
        has_point = p3D_ids != -1
        if not has_point.any():
            # No observation carries a 3D point: nothing to correct, and
            # np.stack below would raise on an empty list.
            continue
        tracked_ids = p3D_ids[has_point]

        depth_name = image_path_to_rendered_depth_path(img.name)
        # The context manager releases the file handle once the pixels are
        # copied into the array.
        with PIL.Image.open(depth_folder_path / depth_name) as depth_image:
            depth = np.array(depth_image).astype("float64")
        # NOTE(review): presumably converts millimetres to metres - confirm.
        depth = depth / 1000.0
        depth[(depth == 0.0) | (depth > 1000.0)] = np.nan

        R_w2c, t_w2c = img.qvec2rotmat(), img.tvec
        camera = cameras[img.camera_id]
        p3Ds = np.stack([points3D[i].xyz for i in tracked_ids], 0)

        p2Ds, valids_projected = project_to_image(p3Ds, R_w2c, t_w2c, camera)
        interp_depth, valids_backprojected = interpolate_depth(depth, p2Ds)
        scs = scene_coordinates(
            p2Ds[valids_backprojected],
            R_w2c,
            t_w2c,
            interp_depth[valids_backprojected],
            camera,
        )

        # keep[i] is True iff tracked observation i projects inside the image
        # AND a depth value could be sampled at its projection.
        keep = valids_projected.copy()
        keep[valids_projected] = valids_backprojected

        # Remove this image from the track of every rejected point.
        for p3did in tracked_ids[~keep]:
            point = points3D[p3did]
            stale = np.where(point.image_ids == img.id)[0]
            points3D[p3did] = point._replace(
                image_ids=np.delete(point.image_ids, stale),
                point2D_idxs=np.delete(point.point2D_idxs, stale),
            )

        # Mark the rejected observations of the image as having no 3D point.
        new_p3D_ids = p3D_ids.copy()
        new_p3D_ids[has_point] = np.where(keep, tracked_ids, -1)
        img = img._replace(point3D_ids=new_p3D_ids)

        # Internal invariant: one lifted position per surviving observation.
        surviving_ids = tracked_ids[keep]
        assert len(surviving_ids) == len(
            scs
        ), f"{len(scs)}, {len(surviving_ids)}"
        for p3did, xyz in zip(surviving_ids, scs):
            points3D[p3did] = points3D[p3did]._replace(xyz=xyz)
        images[imgid] = img

    output_path.mkdir(parents=True, exist_ok=True)
    write_model(cameras, images, points3D, output_path)
|
|
|
|
|
if __name__ == "__main__":
    # Root folders of the input dataset and of the pipeline outputs.
    dataset = Path("datasets/7scenes")
    outputs = Path("outputs/7Scenes")

    SCENES = ["chess", "fire", "heads", "office", "pumpkin", "redkitchen", "stairs"]

    # Correct the SfM model of each scene with its depth maps.
    for scene in SCENES:
        scene_outputs = outputs / scene
        correct_sfm_with_gt_depth(
            scene_outputs / "sfm_superpoint+superglue",
            dataset / f"depth/7scenes_{scene}/train/depth",
            scene_outputs / "sfm_superpoint+superglue+depth",
        )
|
|