import os
from scipy.interpolate import griddata as interp_grid
from tqdm import tqdm
import numpy as np
import cv2
from PIL import Image
import torch
from packaging import version as pver
import torch.nn.functional as F
def trajectory_to_camera_poses_v1(traj, intrinsics, sample_n_frames, zc=1.0):
    '''
    traj: [n_frame, 2], pixel-space (x, y) trajectory
    intrinsics: array whose first row is (fx, fy, cx, cy)
    zc: scalar or per-frame list of depths used to lift pixels into camera space
    Returns translation-only w2c poses, [n_frame, 4, 4]; the first frame is the identity.
    '''
    if not isinstance(zc, list):
        assert isinstance(zc, (float, int)), 'zc should be a float or int or a list of float or int'
        zc = [zc] * traj.shape[0]
    zc = np.array(zc, dtype=traj.dtype)
    # Unproject the pixel trajectory into camera space at depth zc.
    xc = (traj[:, 0] - intrinsics[0, 2]) * zc / intrinsics[0, 0]
    yc = (traj[:, 1] - intrinsics[0, 3]) * zc / intrinsics[0, 1]
    first_frame_w2c = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ], dtype=np.float32)
    # Translations are measured relative to the first frame's 3D point.
    xw, yw, zw = xc[0], yc[0], zc[0]
    Tx = xc - xw
    Ty = yc - yw
    Tz = zc - zw
    traj_w2c = [first_frame_w2c]
    for i in range(1, sample_n_frames):
        w2c_mat = np.array([
            [1, 0, 0, Tx[i]],
            [0, 1, 0, Ty[i]],
            [0, 0, 1, Tz[i]],
            [0, 0, 0, 1]
        ], dtype=first_frame_w2c.dtype)
        traj_w2c.append(w2c_mat)
    traj_w2c = np.stack(traj_w2c, axis=0)
    return traj_w2c  # [n_frame, 4, 4]
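
# Usage sketch for trajectory_to_camera_poses_v1 (all numbers below are
# hypothetical, not taken from this repo): a 16-point horizontal pixel
# trajectory is lifted to translation-only w2c poses at fixed depth zc=1.0.
#
#   >>> traj = np.stack([np.linspace(100, 200, 16),
#   ...                  np.full(16, 160.0)], axis=1)    # [16, 2], (x, y)
#   >>> intr = np.array([[500.0, 500.0, 288.0, 160.0]])  # [[fx, fy, cx, cy]]
#   >>> poses = trajectory_to_camera_poses_v1(traj, intr, 16)
#   >>> poses.shape
#   (16, 4, 4)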
def Unprojected(image_curr: np.ndarray,
                depth_curr: np.ndarray,
                RTs: np.ndarray,
                H: int = 320, W: int = 576,
                K: np.ndarray = None,
                dtype=np.float32):
    '''
    image_curr: [H, W, c], float, 0-1
    depth_curr: [H, W], float32, in meters
    RTs: [num_frames, 3, 4], the camera poses, w2c
    K: [3, 3], pinhole intrinsic matrix
    Returns a list of num_frames - 1 images [H, W, c]: the first frame
    warped into each of the remaining viewpoints.
    '''
    x, y = np.meshgrid(np.arange(W, dtype=dtype), np.arange(H, dtype=dtype), indexing='xy')  # pixel grid
    RTs = RTs.astype(dtype)
    depth_curr = depth_curr.astype(dtype)
    image_curr = image_curr.reshape(H * W, -1).astype(dtype)  # [0, 1]
    R0, T0 = RTs[0, :, :3], RTs[0, :, 3:4]
    # Back-project the reference frame's pixels into camera space, then into
    # world space via the inverse of the first (w2c) pose.
    pts_coord_cam = np.matmul(np.linalg.inv(K), np.stack((x * depth_curr, y * depth_curr, 1 * depth_curr), axis=0).reshape(3, -1))
    pts_coord_world = np.linalg.inv(R0).dot(pts_coord_cam) - np.linalg.inv(R0).dot(T0)
    pts_colors = image_curr
    images = []
    for i in tqdm(range(1, RTs.shape[0])):
        R, T = RTs[i, :, :3], RTs[i, :, 3:4]
        ### Transform world points into the i-th camera (w2c), then project.
        pts_coord_cam2 = R.dot(pts_coord_world) + T
        pixel_coord_cam2 = np.matmul(K, pts_coord_cam2)
        # Keep only points in front of the camera that project inside the frame.
        valid_idx = np.where(np.logical_and.reduce((pixel_coord_cam2[2] > 0,
                                                    pixel_coord_cam2[0] / pixel_coord_cam2[2] >= 0,
                                                    pixel_coord_cam2[0] / pixel_coord_cam2[2] <= W - 1,
                                                    pixel_coord_cam2[1] / pixel_coord_cam2[2] >= 0,
                                                    pixel_coord_cam2[1] / pixel_coord_cam2[2] <= H - 1)))[0]
        pixel_coord_cam2 = pixel_coord_cam2[:2, valid_idx] / pixel_coord_cam2[-1:, valid_idx]
        # Resample the scattered projections back onto the regular pixel grid.
        x, y = np.meshgrid(np.arange(W, dtype=dtype), np.arange(H, dtype=dtype), indexing='xy')
        grid = np.stack((x, y), axis=-1).reshape(-1, 2)
        image2 = interp_grid(pixel_coord_cam2.transpose(1, 0), pts_colors[valid_idx], grid, method='linear', fill_value=0).reshape(H, W, -1)
        images.append(image2)
    print(f'Total {len(images)} images, each image shape: {images[0].shape}')
    return images
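
# Usage sketch for Unprojected (hypothetical values; a small frame keeps the
# griddata resampling fast): warp one RGB frame into the viewpoints of the
# remaining poses using its depth map. Note K here is a 3x3 pinhole matrix,
# unlike the flat (fx, fy, cx, cy) layout used above.
#
#   >>> H, W = 64, 96
#   >>> K = np.array([[80.0, 0.0, W / 2], [0.0, 80.0, H / 2], [0.0, 0.0, 1.0]])
#   >>> image = np.random.rand(H, W, 3).astype(np.float32)
#   >>> depth = np.ones((H, W), dtype=np.float32)
#   >>> traj = np.stack([np.linspace(W / 2, W / 2 + 10, 4), np.full(4, H / 2)], axis=1)
#   >>> RTs = trajectory_to_camera_poses_v1(traj, np.array([[80.0, 80.0, W / 2, H / 2]]), 4)[:, :3, :]
#   >>> warped = Unprojected(image, depth, RTs, H=H, W=W, K=K)  # 3 images, each [H, W, 3]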
class Camera(object):
    '''
    entry: flat sequence laid out as
        [_, fx, fy, cx, cy, _, _, w2c_00, ..., w2c_23]
    where the last 12 values are a row-major 3x4 w2c matrix.
    '''
    def __init__(self, entry):
        fx, fy, cx, cy = entry[1:5]
        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy
        w2c_mat = np.array(entry[7:]).reshape(3, 4)
        w2c_mat_4x4 = np.eye(4)
        w2c_mat_4x4[:3, :] = w2c_mat
        self.w2c_mat = w2c_mat_4x4
        self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
def get_relative_pose(cam_params, zero_t_first_frame):
    '''
    Re-express the c2w poses relative to the first camera, which is mapped to
    the origin (or, when zero_t_first_frame is False, placed at its original
    distance from the origin along -y).
    '''
    abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
    abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
    source_cam_c2w = abs_c2ws[0]
    if zero_t_first_frame:
        cam_to_origin = 0
    else:
        cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])
    target_cam_c2w = np.array([
        [1, 0, 0, 0],
        [0, 1, 0, -cam_to_origin],
        [0, 0, 1, 0],
        [0, 0, 0, 1]
    ])
    abs2rel = target_cam_c2w @ abs_w2cs[0]
    ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
    ret_poses = np.array(ret_poses, dtype=np.float32)
    return ret_poses
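
# Quick sanity check for get_relative_pose (hypothetical values): with
# identity w2c poses and zero_t_first_frame=True, every returned c2w pose is
# the identity, since the first camera is mapped onto the world origin.
#
#   >>> entry = [0, 0.9, 1.6, 0.5, 0.5, 0, 0] + np.eye(4)[:3].reshape(-1).tolist()
#   >>> cams = [Camera(entry) for _ in range(2)]
#   >>> get_relative_pose(cams, zero_t_first_frame=True).shape
#   (2, 4, 4)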
def custom_meshgrid(*args):
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    # torch.meshgrid gained the `indexing` argument in 1.10; older versions
    # always behave as indexing='ij'.
    if pver.parse(torch.__version__) < pver.parse('1.10'):
        return torch.meshgrid(*args)
    else:
        return torch.meshgrid(*args, indexing='ij')
def ray_condition(K, c2w, H, W, device, flip_flag=None):
    # c2w: B, V, 4, 4
    # K: B, V, 4 (fx, fy, cx, cy)
    B, V = K.shape[:2]
    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5  # [B, V, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5  # [B, V, HxW]
    n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
    if n_flip > 0:
        # Horizontally flipped frames sample the x coordinate right-to-left.
        j_flip, i_flip = custom_meshgrid(
            torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
            torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype)
        )
        i_flip = i_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
        j_flip = j_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
        i[:, flip_flag, ...] = i_flip
        j[:, flip_flag, ...] = j_flip
    fx, fy, cx, cy = K.chunk(4, dim=-1)  # B, V, 1
    zs = torch.ones_like(i)  # [B, V, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)
    # Unit ray directions in camera space, rotated into world space.
    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, HW, 3
    rays_o = c2w[..., :3, 3]  # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, HW, 3
    # Plucker coordinates of each ray: the moment o x d concatenated with d.
    rays_dxo = torch.cross(rays_o, rays_d, dim=-1)  # B, V, HW, 3
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
    return plucker, rays_o, rays_d
def RT2Plucker(RT, num_frames, sample_size, fx, fy, cx, cy):
    '''
    RT: [num_frames, 3, 4], w2c
    sample_size: (H, W)
    fx, fy, cx, cy: intrinsics normalized by the frame size; they are
        rescaled to pixels with sample_size below.
    Returns the Plucker embedding [6, num_frames, H, W] plus ray origins
    and directions.
    '''
    cam_params = []
    for i in range(num_frames):
        cam_params.append(Camera([0, fx, fy, cx, cy, 0, 0, RT[i].reshape(-1)]))
    intrinsics = np.asarray([[cam_param.fx * sample_size[1],
                              cam_param.fy * sample_size[0],
                              cam_param.cx * sample_size[1],
                              cam_param.cy * sample_size[0]]
                             for cam_param in cam_params], dtype=np.float32)
    intrinsics = torch.as_tensor(intrinsics)[None]  # [1, n_frame, 4]
    relative_pose = True
    zero_t_first_frame = True
    if relative_pose:
        c2w_poses = get_relative_pose(cam_params, zero_t_first_frame)
    else:
        c2w_poses = np.array([cam_param.c2w_mat for cam_param in cam_params], dtype=np.float32)
    c2w = torch.as_tensor(c2w_poses)[None]  # [1, n_frame, 4, 4]
    flip_flag = torch.zeros(num_frames, dtype=torch.bool, device=c2w.device)
    plucker_embedding, rays_o, rays_d = ray_condition(intrinsics, c2w, sample_size[0], sample_size[1], device='cpu',
                                                      flip_flag=flip_flag)
    plucker_embedding = plucker_embedding[0].permute(0, 3, 1, 2).contiguous()  # V, 6, H, W
    plucker_embedding = plucker_embedding.permute(1, 0, 2, 3).contiguous()  # 6, V, H, W
    return plucker_embedding, rays_o, rays_d
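
# Usage sketch for RT2Plucker (hypothetical values): fx, fy, cx, cy are passed
# normalized by the frame size, since RT2Plucker rescales them by sample_size
# internally before building the Plucker embedding.
#
#   >>> traj = np.stack([np.linspace(100, 200, 16), np.full(16, 160.0)], axis=1)
#   >>> RT = trajectory_to_camera_poses_v1(traj, np.array([[500.0, 500.0, 288.0, 160.0]]), 16)[:, :3, :]
#   >>> plucker, rays_o, rays_d = RT2Plucker(RT, 16, (320, 576),
#   ...                                      fx=500 / 576, fy=500 / 320, cx=0.5, cy=0.5)
#   >>> plucker.shape   # [6, 16, 320, 576]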
def visualized_trajectories(images, trajectories, save_path, save_each_frame=False):
    '''
    images: [n_frame, H, W, 3], numpy, 0-255
    trajectories: [[n_frame, 2]], list[numpy], x, y
    save_path: str, ends with .gif
    '''
    pil_image = []
    H, W = images.shape[1], images.shape[2]
    n_frame = images.shape[0]
    for i in range(n_frame):
        image = images[i].astype(np.uint8)
        image = cv2.UMat(image)
        for traj in trajectories:
            line_data = traj[:i + 1]
            # Clamp the (rounded) start point into the frame; it is marked in green.
            x = min(int(round(line_data[0][0])), W - 1)
            y = min(int(round(line_data[0][1])), H - 1)
            if len(line_data) == 1:
                cv2.circle(image, (x, y), 1, (0, 255, 0), 8)
            else:
                for j in range(1, len(line_data)):
                    x0, y0 = int(round(line_data[j - 1][0])), int(round(line_data[j - 1][1]))
                    x1, y1 = int(round(line_data[j][0])), int(round(line_data[j][1]))
                    x0, y0 = min(x0, W - 1), min(y0, H - 1)
                    x1, y1 = min(x1, W - 1), min(y1, H - 1)
                    if x0 != x1 or y0 != y1:
                        if j == len(line_data) - 1:
                            # Finish the polyline with an arrow head of roughly
                            # constant pixel size.
                            line_length = np.sqrt((x1 - x0) ** 2 + (y1 - y0) ** 2)
                            arrow_head_length = 10
                            tip_length = arrow_head_length / line_length
                            cv2.arrowedLine(image, (x0, y0), (x1, y1), (255, 0, 0), 6, tipLength=tip_length)
                        else:
                            cv2.line(image, (x0, y0), (x1, y1), (255, 0, 0), 6)
                cv2.circle(image, (x, y), 1, (0, 255, 0), 8)
        image = cv2.UMat.get(image)
        pil_image.append(Image.fromarray(image))
    pil_image[0].save(save_path, save_all=True, append_images=pil_image[1:], loop=0, duration=100)
    # Optionally also dump each frame as a PNG next to the GIF.
    if save_each_frame:
        img_save_root = save_path.replace('.gif', '')
        os.makedirs(img_save_root, exist_ok=True)
        for i, img in enumerate(pil_image):
            img.save(os.path.join(img_save_root, f'{i:05d}.png'))
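
# Usage sketch for visualized_trajectories (hypothetical values): overlay a
# single drifting trajectory on a short synthetic clip and write out a GIF.
#
#   >>> frames = np.zeros((8, 320, 576, 3), dtype=np.uint8)
#   >>> traj = np.stack([np.linspace(50, 150, 8), np.linspace(60, 120, 8)], axis=1)
#   >>> visualized_trajectories(frames, [traj], 'traj_vis.gif')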
def roll_with_ignore_multidim(arr, shifts):
    '''Apply roll_with_ignore along each axis in turn; shifts[k] shifts axis k.'''
    result = np.copy(arr)
    for axis, shift in enumerate(shifts):
        result = roll_with_ignore(result, shift, axis)
    return result

def roll_with_ignore(arr, shift, axis):
    '''Like np.roll, except entries shifted past the edge are dropped and the
    vacated entries are filled with zeros instead of wrapping around.'''
    result = np.zeros_like(arr)
    if shift > 0:
        result[tuple(slice(shift, None) if i == axis else slice(None) for i in range(arr.ndim))] = \
            arr[tuple(slice(None, -shift) if i == axis else slice(None) for i in range(arr.ndim))]
    elif shift < 0:
        result[tuple(slice(None, shift) if i == axis else slice(None) for i in range(arr.ndim))] = \
            arr[tuple(slice(-shift, None) if i == axis else slice(None) for i in range(arr.ndim))]
    else:
        result = arr
    return result
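
# Quick sanity check for roll_with_ignore (hypothetical values): unlike
# np.roll, shifted-out entries are dropped and the vacated slots become zero.
#
#   >>> a = np.arange(1, 5)            # [1, 2, 3, 4]
#   >>> roll_with_ignore(a, 1, axis=0)
#   array([0, 1, 2, 3])
#   >>> roll_with_ignore(a, -1, axis=0)
#   array([2, 3, 4, 0])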
def dilate_mask_pytorch(mask, kernel_size=2):
    '''
    mask: torch.Tensor, shape [b, 1, h, w] (single channel)
    kernel_size: int; only odd kernel sizes keep the spatial size unchanged
        with padding=kernel_size//2
    '''
    # All-ones structuring element for the dilation
    kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype=mask.dtype, device=mask.device)
    # A pixel becomes foreground if any pixel under the kernel is foreground
    dilated_mask = F.conv2d(mask, kernel, padding=kernel_size // 2)
    # Ensure the output is still a binary mask (0 and 1)
    dilated_mask = (dilated_mask > 0).to(mask.dtype)
    return dilated_mask
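
# Minimal sketch for dilate_mask_pytorch (hypothetical values): a single
# foreground pixel grows into a 3x3 block with an odd kernel size.
#
#   >>> m = torch.zeros(1, 1, 5, 5)
#   >>> m[0, 0, 2, 2] = 1.0
#   >>> dilate_mask_pytorch(m, kernel_size=3)[0, 0].sum()
#   tensor(9.)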
def smooth_mask(mask, kernel_size=3):
    '''
    mask: torch.Tensor, shape [b, 1, h, w] (single channel)
    kernel_size: int; odd kernel sizes keep the spatial size unchanged
    '''
    # Box (mean) kernel for smoothing; note this is a uniform average,
    # not a Gaussian
    kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype=mask.dtype, device=mask.device) / (kernel_size * kernel_size)
    # Local average followed by a majority-vote threshold
    smoothed_mask = F.conv2d(mask, kernel, padding=kernel_size // 2)
    # Ensure the output is still a binary mask (0 and 1)
    smoothed_mask = (smoothed_mask > 0.5).to(mask.dtype)
    return smoothed_mask
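
# Minimal sketch for smooth_mask (hypothetical values): an isolated pixel is
# erased, since its 3x3 box average (1/9) never clears the 0.5 threshold.
#
#   >>> m = torch.zeros(1, 1, 5, 5)
#   >>> m[0, 0, 2, 2] = 1.0
#   >>> smooth_mask(m, kernel_size=3).sum()
#   tensor(0.)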