import os
import torch
import numpy as np
import torch.nn.functional as F


def build_patch_offset(h_patch_size):
    # Integer pixel offsets covering a square patch of side (2 * h_patch_size + 1),
    # returned as (dx, dy) pairs with shape [1, (2 * h_patch_size + 1) ** 2, 2]
    # so they can be broadcast onto sampled pixel centers.
    offsets = torch.arange(-h_patch_size, h_patch_size + 1)
    return torch.stack(torch.meshgrid(offsets, offsets, indexing="ij")[::-1], dim=-1).view(1, -1, 2)
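
# Illustrative usage sketch (values assumed, not part of the original code): with
# h_patch_size=3 the offsets span a 7x7 window, so the result has shape [1, 49, 2].
#
#   offsets = build_patch_offset(3)
#   print(offsets.shape)  # torch.Size([1, 49, 2])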


def gen_rays_from_single_image(H, W, image, intrinsic, c2w, depth=None, mask=None):
    """
    generate rays in world space, for a single image
    :param H:
    :param W:
    :param image: [3, H, W]
    :param intrinsic: [3,3]
    :param c2w: [4,4]
    :param depth: [H, W]
    :param mask: [H, W]
    :return:
    """
    device = image.device
    ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H),
                            torch.linspace(0, W - 1, W), indexing="ij")
    p = torch.stack([xs, ys, torch.ones_like(ys)], dim=-1)  # [H, W, 3] homogeneous pixel coordinates

    # normalized ndc uv coordinates in (-1, 1)
    ndc_u = 2 * xs / (W - 1) - 1
    ndc_v = 2 * ys / (H - 1) - 1
    rays_ndc_uv = torch.stack([ndc_u, ndc_v], dim=-1).view(-1, 2).float().to(device)

    intrinsic_inv = torch.inverse(intrinsic)

    p = p.view(-1, 3).float().to(device)  # [H*W, 3]
    p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze()  # back-project to camera space
    rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # unit view directions in camera space
    rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze()  # rotate directions into world space
    rays_o = c2w[None, :3, 3].expand(rays_v.shape)  # camera center, shared by all rays

    image = image.permute(1, 2, 0)
    color = image.reshape(-1, 3)  # reshape (not view): image may be non-contiguous after permute
    depth = depth.view(-1, 1) if depth is not None else None
    mask = mask.view(-1, 1) if mask is not None else torch.ones([H * W, 1]).to(device)
    sample = {
        'rays_o': rays_o,
        'rays_v': rays_v,
        'rays_ndc_uv': rays_ndc_uv,
        'rays_color': color,
        'rays_mask': mask,
        'rays_norm_XYZ_cam': p  # camera-space coordinates K^-1 [u, v, 1], before scaling by depth
    }
    if depth is not None:
        sample['rays_depth'] = depth

    return sample
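
# Illustrative usage sketch (shapes and intrinsics assumed, not taken from the original repo):
#
#   H, W = 480, 640
#   image = torch.rand(3, H, W)                           # [3, H, W] RGB
#   intrinsic = torch.tensor([[500., 0., 320.],
#                             [0., 500., 240.],
#                             [0., 0., 1.]])
#   c2w = torch.eye(4)
#   rays = gen_rays_from_single_image(H, W, image, intrinsic, c2w)
#   # rays['rays_o'], rays['rays_v'], rays['rays_color'] each have shape [H * W, 3]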


def gen_random_rays_from_single_image(H, W, N_rays, image, intrinsic, c2w, depth=None, mask=None, dilated_mask=None,
                                       importance_sample=False, h_patch_size=3):
    """
    generate random rays in world space, for a single image
    :param H:
    :param W:
    :param N_rays:
    :param image: [3, H, W]
    :param intrinsic: [3,3]
    :param c2w: [4,4]
    :param depth: [H, W]
    :param mask: [H, W]
    :param dilated_mask: [H, W], required when importance_sample is True
    :param importance_sample: if True, draw 3/4 of the rays from inside dilated_mask
    :param h_patch_size: half patch size; each ray also gets a (2*h_patch_size+1)^2 color patch
    :return:
    """
    device = image.device

    if dilated_mask is None:
        dilated_mask = mask

    if not importance_sample:
        pixels_x = torch.randint(low=0, high=W, size=[N_rays])
        pixels_y = torch.randint(low=0, high=H, size=[N_rays])
    elif importance_sample and dilated_mask is not None:  # sample more points inside the valid mask region
        pixels_x_1 = torch.randint(low=0, high=W, size=[N_rays // 4])
        pixels_y_1 = torch.randint(low=0, high=H, size=[N_rays // 4])

        ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H),
                                torch.linspace(0, W - 1, W), indexing="ij")
        p = torch.stack([xs, ys], dim=-1)  # [H, W, 2]

        try:
            p_valid = p[dilated_mask > 0]  # [num_valid, 2]
            random_idx = torch.randint(low=0, high=p_valid.shape[0], size=[N_rays // 4 * 3])
        except Exception:
            print("dilated_mask.shape: ", dilated_mask.shape)
            print("dilated_mask valid number", dilated_mask.sum())

            raise ValueError("dilated_mask has no valid pixels to importance-sample from")
        p_select = p_valid[random_idx]  # [N_rays//4*3, 2]
        pixels_x_2 = p_select[:, 0]
        pixels_y_2 = p_select[:, 1]

        pixels_x = torch.cat([pixels_x_1, pixels_x_2], dim=0).to(torch.int64)
        pixels_y = torch.cat([pixels_y_1, pixels_y_2], dim=0).to(torch.int64)

    # crop a color patch around each sampled pixel
    offsets = build_patch_offset(h_patch_size).to(device)
    grid_patch = torch.stack([pixels_x, pixels_y], dim=-1).view(-1, 1, 2) + offsets.float()  # [N_rays, Npx, 2]
    patch_mask = (pixels_x > h_patch_size) * (pixels_x < (W - h_patch_size)) * \
                 (pixels_y > h_patch_size) * (pixels_y < H - h_patch_size)  # patch fully inside the image
    grid_patch_u = 2 * grid_patch[:, :, 0] / (W - 1) - 1
    grid_patch_v = 2 * grid_patch[:, :, 1] / (H - 1) - 1
    grid_patch_uv = torch.stack([grid_patch_u, grid_patch_v], dim=-1)  # [N_rays, Npx, 2]
    patch_color = F.grid_sample(image[None, :, :, :], grid_patch_uv[None, :, :, :], mode='bilinear',
                                padding_mode='zeros', align_corners=True)[0]  # [3, N_rays, Npx]
    patch_color = patch_color.permute(1, 2, 0).contiguous()  # [N_rays, Npx, 3]

    # normalized ndc uv coordinates in (-1, 1)
    ndc_u = 2 * pixels_x / (W - 1) - 1
    ndc_v = 2 * pixels_y / (H - 1) - 1
    rays_ndc_uv = torch.stack([ndc_u, ndc_v], dim=-1).view(-1, 2).float().to(device)

    image = image.permute(1, 2, 0)  # [H, W, 3]
    color = image[(pixels_y, pixels_x)]  # [N_rays, 3]

    if mask is not None:
        mask = mask[(pixels_y, pixels_x)]  # [N_rays]
        patch_mask = patch_mask * mask  # valid patch centers must also lie inside the mask
        mask = mask.view(-1, 1)
    else:
        mask = torch.ones([N_rays, 1]).to(device)

    if depth is not None:
        depth = depth[(pixels_y, pixels_x)]  # [N_rays]
        depth = depth.view(-1, 1)

    intrinsic_inv = torch.inverse(intrinsic)

    p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float().to(device)  # [N_rays, 3]
    p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze()  # back-project to camera space
    rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # unit view directions in camera space
    rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze()  # rotate directions into world space
    rays_o = c2w[None, :3, 3].expand(rays_v.shape)  # camera center, shared by all rays

    sample = {
        'rays_o': rays_o,
        'rays_v': rays_v,
        'rays_ndc_uv': rays_ndc_uv,
        'rays_color': color,
        'rays_mask': mask,
        'rays_norm_XYZ_cam': p,  # camera-space coordinates, before scaling by depth
        'rays_patch_color': patch_color,
        'rays_patch_mask': patch_mask.view(-1, 1)
    }

    if depth is not None:
        sample['rays_depth'] = depth

    return sample
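
# Illustrative usage sketch (image, intrinsic, c2w, valid_mask, dilated_mask are placeholders,
# shapes assumed): with h_patch_size=5 each sampled ray also carries an 11x11 color patch,
# which can feed a patch-based photometric loss.
#
#   sample = gen_random_rays_from_single_image(H, W, N_rays=512, image=image,
#                                              intrinsic=intrinsic, c2w=c2w,
#                                              mask=valid_mask, dilated_mask=dilated_mask,
#                                              importance_sample=True, h_patch_size=5)
#   # sample['rays_patch_color']: [512, 121, 3], sample['rays_patch_mask']: [512, 1]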


def gen_random_rays_of_patch_from_single_image(H, W, N_rays, num_neighboring_pts, patch_size,
                                               image, intrinsic, c2w, depth=None, mask=None):
    """
    generate random rays in world space, for a single image
    sample rays from local patches
    :param H:
    :param W:
    :param N_rays: the number of center rays of patches
    :param num_neighboring_pts: the number of random neighbors sampled around each center ray
    :param patch_size: side length (in pixels) of each local patch
    :param image: [3, H, W]
    :param intrinsic: [3,3]
    :param c2w: [4,4]
    :param depth: [H, W], required here (sampled per ray)
    :param mask: [H, W], required here (sampled per ray)
    :return:
    """
    device = image.device
    patch_radius_max = patch_size // 2

    # ndc extent of one pixel
    unit_u = 2 / (W - 1)
    unit_v = 2 / (H - 1)

    # keep patch centers away from the border so the whole patch stays inside the image
    pixels_x_center = torch.randint(low=patch_size, high=W - patch_size, size=[N_rays])
    pixels_y_center = torch.randint(low=patch_size, high=H - patch_size, size=[N_rays])

    # normalized ndc uv coordinates of the centers, (-1, 1)
    ndc_u_center = 2 * pixels_x_center / (W - 1) - 1
    ndc_v_center = 2 * pixels_y_center / (H - 1) - 1
    ndc_uv_center = torch.stack([ndc_u_center, ndc_v_center], dim=-1).view(-1, 2).float().to(device)[:, None, :]  # [N_rays, 1, 2]

    # random offsets in (-1, 1), scaled to the patch radius
    shift_u, shift_v = torch.rand([N_rays, num_neighboring_pts, 1]), torch.rand([N_rays, num_neighboring_pts, 1])
    shift_u = 2 * (shift_u - 0.5)
    shift_v = 2 * (shift_v - 0.5)

    shift_uv = torch.cat([(shift_u * patch_radius_max) * unit_u, (shift_v * patch_radius_max) * unit_v], dim=-1)
    neighboring_pts_uv = ndc_uv_center + shift_uv  # [N_rays, num_neighboring_pts, 2]

    sampled_pts_uv = torch.cat([ndc_uv_center, neighboring_pts_uv], dim=1)  # prepend the center point

    # sample ground-truth color, depth and mask at the chosen uv locations
    color = F.grid_sample(image[None, :, :, :], sampled_pts_uv[None, :, :, :], mode='bilinear',
                          align_corners=True)[0]  # [3, N_rays, 1 + num_neighboring_pts]
    depth = F.grid_sample(depth[None, None, :, :], sampled_pts_uv[None, :, :, :], mode='bilinear',
                          align_corners=True)[0]  # [1, N_rays, 1 + num_neighboring_pts]
    mask = F.grid_sample(mask[None, None, :, :].to(torch.float32), sampled_pts_uv[None, :, :, :], mode='nearest',
                         align_corners=True).to(torch.int64)[0]  # [1, N_rays, 1 + num_neighboring_pts]

    intrinsic_inv = torch.inverse(intrinsic)

    sampled_pts_uv = sampled_pts_uv.view(N_rays * (1 + num_neighboring_pts), 2)
    color = color.permute(1, 2, 0).contiguous().view(N_rays * (1 + num_neighboring_pts), 3)
    depth = depth.permute(1, 2, 0).contiguous().view(N_rays * (1 + num_neighboring_pts), 1)
    mask = mask.permute(1, 2, 0).contiguous().view(N_rays * (1 + num_neighboring_pts), 1)

    # back from ndc to pixel coordinates, then to rays
    pixels_x = (sampled_pts_uv[:, 0] + 1) * (W - 1) / 2
    pixels_y = (sampled_pts_uv[:, 1] + 1) * (H - 1) / 2
    p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float().to(device)  # [N_pts, 3]
    p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze()  # back-project to camera space
    rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # unit view directions in camera space
    rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze()  # rotate directions into world space
    rays_o = c2w[None, :3, 3].expand(rays_v.shape)  # camera center, shared by all rays

    sample = {
        'rays_o': rays_o,
        'rays_v': rays_v,
        'rays_ndc_uv': sampled_pts_uv,
        'rays_color': color,
        'rays_depth': depth,
        'rays_mask': mask,
    }

    return sample
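
# Illustrative usage sketch (image, intrinsic, c2w, depth, valid_mask are placeholders): sample
# 256 patch centers, each with 8 random neighbors inside a 5x5 window; every returned tensor
# then has N_rays * (1 + num_neighboring_pts) rows.
#
#   sample = gen_random_rays_of_patch_from_single_image(H, W, N_rays=256,
#                                                       num_neighboring_pts=8, patch_size=5,
#                                                       image=image, intrinsic=intrinsic,
#                                                       c2w=c2w, depth=depth, mask=valid_mask)
#   # sample['rays_o']: [256 * 9, 3], sample['rays_depth']: [256 * 9, 1]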


def gen_random_rays_from_batch_images(H, W, N_rays, images, intrinsics, c2ws, depths=None, masks=None):
    """
    :param H:
    :param W:
    :param N_rays:
    :param images: [B,3,H,W]
    :param intrinsics: [B, 3, 3]
    :param c2ws: [B, 4, 4]
    :param depths: [B,H,W]
    :param masks: [B,H,W]
    :return:
    """
    assert len(images.shape) == 4

    rays_o = []
    rays_v = []
    rays_color = []
    rays_depth = []
    rays_mask = []
    for i in range(images.shape[0]):
        sample = gen_random_rays_from_single_image(H, W, N_rays, images[i], intrinsics[i], c2ws[i],
                                                   depth=depths[i] if depths is not None else None,
                                                   mask=masks[i] if masks is not None else None)
        rays_o.append(sample['rays_o'])
        rays_v.append(sample['rays_v'])
        rays_color.append(sample['rays_color'])
        if depths is not None:
            rays_depth.append(sample['rays_depth'])
        if masks is not None:
            rays_mask.append(sample['rays_mask'])

    sample = {
        'rays_o': torch.stack(rays_o, dim=0),
        'rays_v': torch.stack(rays_v, dim=0),
        'rays_color': torch.stack(rays_color, dim=0),
        'rays_depth': torch.stack(rays_depth, dim=0) if depths is not None else None,
        'rays_mask': torch.stack(rays_mask, dim=0) if masks is not None else None
    }
    return sample
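
# Illustrative usage sketch (images, Ks, c2ws are placeholders with the docstring shapes): the
# batched wrapper simply loops over the batch dimension and stacks the per-image results.
#
#   rays = gen_random_rays_from_batch_images(H, W, N_rays=512, images=images,   # [B, 3, H, W]
#                                            intrinsics=Ks, c2ws=c2ws)          # [B, 3, 3], [B, 4, 4]
#   # rays['rays_o'], rays['rays_v'], rays['rays_color']: [B, 512, 3]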


from scipy.spatial.transform import Rotation as Rot
from scipy.spatial.transform import Slerp


def gen_rays_between(c2w_0, c2w_1, intrinsic, ratio, H, W, resolution_level=1):
    device = c2w_0.device

    l = resolution_level
    tx = torch.linspace(0, W - 1, W // l)
    ty = torch.linspace(0, H - 1, H // l)
    pixels_x, pixels_y = torch.meshgrid(tx, ty, indexing="ij")
    p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).to(device)  # [W//l, H//l, 3]

    intrinsic_inv = torch.inverse(intrinsic[:3, :3])
    p = torch.matmul(intrinsic_inv[None, None, :3, :3], p[:, :, :, None]).squeeze()  # back-project to camera space
    rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # unit view directions in camera space

    # interpolate the two camera poses: slerp for the rotation, linear blend for the translation
    pose_0 = c2w_0.detach().cpu().numpy()
    pose_1 = c2w_1.detach().cpu().numpy()
    pose_0 = np.linalg.inv(pose_0)
    pose_1 = np.linalg.inv(pose_1)
    rot_0 = pose_0[:3, :3]
    rot_1 = pose_1[:3, :3]
    rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
    key_times = [0, 1]
    slerp = Slerp(key_times, rots)
    rot = slerp(ratio)
    pose = np.diag([1.0, 1.0, 1.0, 1.0])
    pose = pose.astype(np.float32)
    pose[:3, :3] = rot.as_matrix()
    pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
    pose = np.linalg.inv(pose)

    c2w = torch.from_numpy(pose).to(device)
    rot = torch.from_numpy(pose[:3, :3]).to(device)
    trans = torch.from_numpy(pose[:3, 3]).to(device)
    rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze()  # rotate into world space
    rays_o = trans[None, None, :3].expand(rays_v.shape)  # interpolated camera center
    return c2w, rays_o.transpose(0, 1).contiguous().view(-1, 3), rays_v.transpose(0, 1).contiguous().view(-1, 3)
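

if __name__ == "__main__":
    # Minimal smoke test (not part of the original module): build dummy inputs on CPU and print
    # the shapes returned by the main entry points. Image size and intrinsics are arbitrary
    # placeholders, not values from the original project.
    H, W = 60, 80
    image = torch.rand(3, H, W)
    intrinsic = torch.tensor([[70.0, 0.0, W / 2],
                              [0.0, 70.0, H / 2],
                              [0.0, 0.0, 1.0]])
    c2w_0 = torch.eye(4)
    c2w_1 = torch.eye(4)
    c2w_1[0, 3] = 0.1  # slightly translated second camera

    full = gen_rays_from_single_image(H, W, image, intrinsic, c2w_0)
    print("full-image rays:", full['rays_o'].shape, full['rays_v'].shape)  # [H*W, 3] each

    rnd = gen_random_rays_from_single_image(H, W, 128, image, intrinsic, c2w_0, h_patch_size=3)
    print("random rays:", rnd['rays_o'].shape, rnd['rays_patch_color'].shape)  # [128, 3], [128, 49, 3]

    c2w_mid, rays_o, rays_v = gen_rays_between(c2w_0, c2w_1, intrinsic, ratio=0.5, H=H, W=W,
                                               resolution_level=2)
    print("interpolated view:", c2w_mid.shape, rays_o.shape, rays_v.shape)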