|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import torch |
|
import numpy as np |
|
import PIL.Image |
|
from PIL.ImageOps import exif_transpose |
|
import torchvision.transforms as tvf |
|
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" |
|
import cv2 |
|
import glob |
|
import imageio |
|
import matplotlib.pyplot as plt |
|
|
|
try: |
|
from pillow_heif import register_heif_opener |
|
register_heif_opener() |
|
heif_support_enabled = True |
|
except ImportError: |
|
heif_support_enabled = False |
|
|
|
ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) |
|
ToTensor = tvf.ToTensor() |
|
TAG_FLOAT = 202021.25 |
|
|
|
def depth_read(filename): |
|
""" Read depth data from file, return as numpy array. """ |
|
f = open(filename,'rb') |
|
check = np.fromfile(f,dtype=np.float32,count=1)[0] |
|
assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check) |
|
width = np.fromfile(f,dtype=np.int32,count=1)[0] |
|
height = np.fromfile(f,dtype=np.int32,count=1)[0] |
|
size = width*height |
|
assert width > 0 and height > 0 and size > 1 and size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height) |
|
depth = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width)) |
|
return depth |
|
|
|
def cam_read(filename): |
|
""" Read camera data, return (M,N) tuple. |
|
|
|
M is the intrinsic matrix, N is the extrinsic matrix, so that |
|
|
|
x = M*N*X, |
|
with x being a point in homogeneous image pixel coordinates, X being a |
|
point in homogeneous world coordinates. |
|
""" |
|
f = open(filename,'rb') |
|
check = np.fromfile(f,dtype=np.float32,count=1)[0] |
|
assert check == TAG_FLOAT, ' cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check) |
|
M = np.fromfile(f,dtype='float64',count=9).reshape((3,3)) |
|
N = np.fromfile(f,dtype='float64',count=12).reshape((3,4)) |
|
return M,N |
|
|
|
def flow_read(filename): |
|
""" Read optical flow from file, return (U,V) tuple. |
|
|
|
Original code by Deqing Sun, adapted from Daniel Scharstein. |
|
""" |
|
f = open(filename,'rb') |
|
check = np.fromfile(f,dtype=np.float32,count=1)[0] |
|
assert check == TAG_FLOAT, ' flow_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check) |
|
width = np.fromfile(f,dtype=np.int32,count=1)[0] |
|
height = np.fromfile(f,dtype=np.int32,count=1)[0] |
|
size = width*height |
|
assert width > 0 and height > 0 and size > 1 and size < 100000000, ' flow_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height) |
|
tmp = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width*2)) |
|
u = tmp[:,np.arange(width)*2] |
|
v = tmp[:,np.arange(width)*2 + 1] |
|
return u,v |
|
|
|
def img_to_arr( img ): |
|
if isinstance(img, str): |
|
img = imread_cv2(img) |
|
return img |
|
|
|
def imread_cv2(path, options=cv2.IMREAD_COLOR): |
|
""" Open an image or a depthmap with opencv-python. |
|
""" |
|
if path.endswith(('.exr', 'EXR')): |
|
options = cv2.IMREAD_ANYDEPTH |
|
img = cv2.imread(path, options) |
|
if img is None: |
|
raise IOError(f'Could not load image={path} with {options=}') |
|
if img.ndim == 3: |
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) |
|
return img |
|
|
|
|
|
def rgb(ftensor, true_shape=None): |
|
if isinstance(ftensor, list): |
|
return [rgb(x, true_shape=true_shape) for x in ftensor] |
|
if isinstance(ftensor, torch.Tensor): |
|
ftensor = ftensor.detach().cpu().numpy() |
|
if ftensor.ndim == 3 and ftensor.shape[0] == 3: |
|
ftensor = ftensor.transpose(1, 2, 0) |
|
elif ftensor.ndim == 4 and ftensor.shape[1] == 3: |
|
ftensor = ftensor.transpose(0, 2, 3, 1) |
|
if true_shape is not None: |
|
H, W = true_shape |
|
ftensor = ftensor[:H, :W] |
|
if ftensor.dtype == np.uint8: |
|
img = np.float32(ftensor) / 255 |
|
else: |
|
img = (ftensor * 0.5) + 0.5 |
|
return img.clip(min=0, max=1) |
|
|
|
|
|
def _resize_pil_image(img, long_edge_size, nearest=False): |
|
S = max(img.size) |
|
if S > long_edge_size: |
|
interp = PIL.Image.LANCZOS if not nearest else PIL.Image.NEAREST |
|
elif S <= long_edge_size: |
|
interp = PIL.Image.BICUBIC |
|
new_size = tuple(int(round(x*long_edge_size/S)) for x in img.size) |
|
return img.resize(new_size, interp) |
|
|
|
def resize_numpy_image(img, long_edge_size): |
|
""" |
|
Resize the NumPy image to a specified long edge size using OpenCV. |
|
|
|
Args: |
|
img (numpy.ndarray): Input image with shape (H, W, C). |
|
long_edge_size (int): The size of the long edge after resizing. |
|
|
|
Returns: |
|
numpy.ndarray: The resized image. |
|
""" |
|
|
|
h, w = img.shape[:2] |
|
S = max(h, w) |
|
|
|
|
|
if S > long_edge_size: |
|
interp = cv2.INTER_LANCZOS4 |
|
else: |
|
interp = cv2.INTER_CUBIC |
|
|
|
|
|
new_size = (int(round(w * long_edge_size / S)), int(round(h * long_edge_size / S))) |
|
|
|
|
|
resized_img = cv2.resize(img, new_size, interpolation=interp) |
|
|
|
return resized_img |
|
|
|
def crop_center(img, crop_width, crop_height): |
|
""" |
|
Crop the center of the image. |
|
|
|
Args: |
|
img (numpy.ndarray): Input image with shape (H, W) or (H, W, C). |
|
crop_width (int): The width of the cropped area. |
|
crop_height (int): The height of the cropped area. |
|
|
|
Returns: |
|
numpy.ndarray: The cropped image. |
|
""" |
|
h, w = img.shape[:2] |
|
cx, cy = h // 2, w // 2 |
|
x1 = max(cx - crop_height // 2, 0) |
|
x2 = min(cx + crop_height // 2, h) |
|
y1 = max(cy - crop_width // 2, 0) |
|
y2 = min(cy + crop_width // 2, w) |
|
|
|
cropped_img = img[x1:x2, y1:y2] |
|
|
|
return cropped_img |
|
|
|
def crop_img(img, size, pred_depth=None, square_ok=False, nearest=False, crop=True): |
|
W1, H1 = img.size |
|
if size == 224: |
|
|
|
img = _resize_pil_image(img, round(size * max(W1/H1, H1/W1)), nearest=nearest) |
|
if pred_depth is not None: |
|
pred_depth = resize_numpy_image(pred_depth, round(size * max(W1 / H1, H1 / W1))) |
|
else: |
|
|
|
img = _resize_pil_image(img, size, nearest=nearest) |
|
if pred_depth is not None: |
|
pred_depth = resize_numpy_image(pred_depth, size) |
|
W, H = img.size |
|
cx, cy = W//2, H//2 |
|
if size == 224: |
|
half = min(cx, cy) |
|
img = img.crop((cx-half, cy-half, cx+half, cy+half)) |
|
if pred_depth is not None: |
|
pred_depth = crop_center(pred_depth, 2 * half, 2 * half) |
|
else: |
|
halfw, halfh = ((2*cx)//16)*8, ((2*cy)//16)*8 |
|
if not (square_ok) and W == H: |
|
halfh = 3*halfw/4 |
|
if crop: |
|
img = img.crop((cx-halfw, cy-halfh, cx+halfw, cy+halfh)) |
|
if pred_depth is not None: |
|
pred_depth = crop_center(pred_depth, 2 * halfw, 2 * halfh) |
|
else: |
|
img = img.resize((2*halfw, 2*halfh), PIL.Image.LANCZOS) |
|
if pred_depth is not None: |
|
pred_depth = cv2.resize(pred_depth, (2*halfw, 2*halfh), interpolation=cv2.INTER_CUBIC) |
|
return img, pred_depth |
|
|
|
def pixel_to_pointcloud(depth_map, focal_length_px): |
|
""" |
|
Convert a depth map to a 3D point cloud. |
|
|
|
Args: |
|
depth_map (numpy.ndarray): The input depth map with shape (H, W), where each value represents the depth at that pixel. |
|
focal_length_px (float): The focal length of the camera in pixels. |
|
|
|
Returns: |
|
numpy.ndarray: The resulting point cloud with shape (H, W, 3), where each point is represented by (X, Y, Z). |
|
""" |
|
height, width = depth_map.shape |
|
cx = width / 2 |
|
cy = height / 2 |
|
|
|
|
|
u = np.arange(width) |
|
v = np.arange(height) |
|
u, v = np.meshgrid(u, v) |
|
|
|
|
|
Z = depth_map |
|
X = (u - cx) * Z / focal_length_px |
|
Y = (v - cy) * Z / focal_length_px |
|
|
|
|
|
point_cloud = np.dstack((X, Y, Z)).astype(np.float32) |
|
point_cloud = normalize_pointcloud(point_cloud) |
|
|
|
|
|
|
|
return point_cloud |
|
|
|
def normalize_pointcloud(point_cloud): |
|
min_vals = np.min(point_cloud, axis=(0, 1)) |
|
max_vals = np.max(point_cloud, axis=(0, 1)) |
|
|
|
normalized_point_cloud = (point_cloud - min_vals) / (max_vals - min_vals) |
|
return normalized_point_cloud |
|
|
|
def load_images(folder_or_list, depth_list, focallength_px_list, size, square_ok=False, verbose=True, dynamic_mask_root=None, crop=True, fps=0, traj_format="sintel", start=0, interval=30, depth_prior_name='depthpro'): |
|
"""Open and convert all images or videos in a list or folder to proper input format for DUSt3R.""" |
|
if isinstance(folder_or_list, str): |
|
if verbose: |
|
print(f'>> Loading images from {folder_or_list}') |
|
|
|
if os.path.isdir(folder_or_list): |
|
root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list)) |
|
else: |
|
root, folder_content = '', [folder_or_list] |
|
|
|
elif isinstance(folder_or_list, list): |
|
if verbose: |
|
print(f'>> Loading a list of {len(folder_or_list)} items') |
|
root, folder_content = '', folder_or_list |
|
|
|
else: |
|
raise ValueError(f'Bad input {folder_or_list=} ({type(folder_or_list)})') |
|
|
|
supported_images_extensions = ['.jpg', '.jpeg', '.png'] |
|
supported_video_extensions = ['.mp4', '.avi', '.mov'] |
|
if heif_support_enabled: |
|
supported_images_extensions += ['.heic', '.heif'] |
|
supported_images_extensions = tuple(supported_images_extensions) |
|
supported_video_extensions = tuple(supported_video_extensions) |
|
|
|
imgs = [] |
|
|
|
|
|
|
|
|
|
for i, path in enumerate(folder_content): |
|
full_path = os.path.join(root, path) |
|
if path.lower().endswith(supported_images_extensions): |
|
|
|
img = exif_transpose(PIL.Image.open(full_path)).convert('RGB') |
|
|
|
pred_depth1 = depth_list[i] |
|
focal_length_px = focallength_px_list[i] |
|
|
|
if len(pred_depth1.shape) == 3: |
|
pred_depth1 = np.squeeze(pred_depth1) |
|
pred_depth = pixel_to_pointcloud(pred_depth1, focal_length_px) |
|
W1, H1 = img.size |
|
img, pred_depth = crop_img(img, size, pred_depth, square_ok=square_ok, crop=crop) |
|
W2, H2 = img.size |
|
if verbose: |
|
print(f' - Adding {path} with resolution {W1}x{H1} --> {W2}x{H2}') |
|
|
|
single_dict = dict( |
|
img=ImgNorm(img)[None], |
|
pred_depth=pred_depth[None,...], |
|
true_shape=np.int32([img.size[::-1]]), |
|
idx=len(imgs), |
|
instance=full_path, |
|
mask=~(ToTensor(img)[None].sum(1) <= 0.01) |
|
) |
|
|
|
if dynamic_mask_root is not None: |
|
dynamic_mask_path = os.path.join(dynamic_mask_root, os.path.basename(path)) |
|
else: |
|
dynamic_mask_path = '' |
|
|
|
|
|
if os.path.exists(dynamic_mask_path): |
|
dynamic_mask = PIL.Image.open(dynamic_mask_path).convert('L') |
|
dynamic_mask, _ = crop_img(dynamic_mask, size, square_ok=square_ok) |
|
|
|
dynamic_mask = ToTensor(dynamic_mask)[None].sum(1) > 0.99 |
|
single_dict['dynamic_mask'] = dynamic_mask |
|
|
|
|
|
|
|
|
|
else: |
|
single_dict['dynamic_mask'] = torch.zeros_like(single_dict['mask']) |
|
|
|
imgs.append(single_dict) |
|
|
|
elif path.lower().endswith(supported_video_extensions): |
|
|
|
if verbose: |
|
print(f'>> Loading video from {full_path}') |
|
cap = cv2.VideoCapture(full_path) |
|
if not cap.isOpened(): |
|
print(f'Error opening video file {full_path}') |
|
continue |
|
|
|
video_fps = cap.get(cv2.CAP_PROP_FPS) |
|
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) |
|
|
|
if video_fps == 0: |
|
print(f'Error: Video FPS is 0 for {full_path}') |
|
cap.release() |
|
continue |
|
if fps > 0: |
|
frame_interval = max(1, int(round(video_fps / fps))) |
|
else: |
|
frame_interval = 1 |
|
frame_indices = list(range(0, total_frames, frame_interval)) |
|
if interval is not None: |
|
frame_indices = frame_indices[:interval] |
|
|
|
if verbose: |
|
print(f' - Video FPS: {video_fps}, Frame Interval: {frame_interval}, Total Frames to Read: {len(frame_indices)}') |
|
|
|
for frame_idx in frame_indices: |
|
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) |
|
ret, frame = cap.read() |
|
if not ret: |
|
break |
|
|
|
img = PIL.Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) |
|
W1, H1 = img.size |
|
img, _ = crop_img(img, size, square_ok=square_ok, crop=crop) |
|
W2, H2 = img.size |
|
|
|
if verbose: |
|
print(f' - Adding frame {frame_idx} from {path} with resolution {W1}x{H1} --> {W2}x{H2}') |
|
|
|
single_dict = dict( |
|
img=ImgNorm(img)[None], |
|
true_shape=np.int32([img.size[::-1]]), |
|
idx=len(imgs), |
|
instance=f'{full_path}_frame_{frame_idx}', |
|
mask=~(ToTensor(img)[None].sum(1) <= 0.01) |
|
) |
|
|
|
|
|
single_dict['dynamic_mask'] = torch.zeros_like(single_dict['mask']) |
|
|
|
imgs.append(single_dict) |
|
|
|
cap.release() |
|
|
|
else: |
|
continue |
|
|
|
assert imgs, 'No images found at ' + root |
|
if verbose: |
|
print(f' (Found {len(imgs)} images)') |
|
return imgs |
|
|
|
def enlarge_seg_masks(folder, kernel_size=5, prefix="dynamic_mask"): |
|
mask_pathes = glob.glob(f'{folder}/{prefix}_*.png') |
|
for mask_path in mask_pathes: |
|
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) |
|
kernel = np.ones((kernel_size, kernel_size),np.uint8) |
|
enlarged_mask = cv2.dilate(mask, kernel, iterations=1) |
|
cv2.imwrite(mask_path.replace(prefix, 'enlarged_dynamic_mask'), enlarged_mask) |
|
|
|
def show_mask(mask, ax, obj_id=None, random_color=False): |
|
if random_color: |
|
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) |
|
else: |
|
cmap = plt.get_cmap("tab10") |
|
cmap_idx = 1 if obj_id is None else obj_id |
|
color = np.array([*cmap(cmap_idx)[:3], 0.6]) |
|
h, w = mask.shape[-2:] |
|
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) |
|
ax.imshow(mask_image) |
|
|
|
def get_overlaied_gif(folder, img_format="frame_*.png", mask_format="dynamic_mask_*.png", output_path="_overlaied.gif"): |
|
img_paths = glob.glob(f'{folder}/{img_format}') |
|
mask_paths = glob.glob(f'{folder}/{mask_format}') |
|
assert len(img_paths) == len(mask_paths), f"Number of images and masks should be the same, got {len(img_paths)} images and {len(mask_paths)} masks" |
|
img_paths = sorted(img_paths) |
|
mask_paths = sorted(mask_paths, key=lambda x: int(x.split('_')[-1].split('.')[0])) |
|
frames = [] |
|
for img_path, mask_path in zip(img_paths, mask_paths): |
|
|
|
img = cv2.imread(img_path) |
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) |
|
|
|
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) |
|
mask = mask.astype(np.float32) / 255.0 |
|
|
|
fig, ax = plt.subplots(figsize=(img.shape[1]/100, img.shape[0]/100), dpi=100) |
|
ax.imshow(img) |
|
|
|
show_mask(mask, ax) |
|
ax.axis('off') |
|
|
|
fig.canvas.draw() |
|
img_array = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) |
|
img_array = img_array.reshape(fig.canvas.get_width_height()[::-1] + (3,)) |
|
frames.append(img_array) |
|
plt.close(fig) |
|
|
|
imageio.mimsave(os.path.join(folder,output_path), frames, fps=10) |
|
|